aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJosh Rosen <joshrosen@databricks.com>2016-04-14 16:43:28 -0700
committerJosh Rosen <joshrosen@databricks.com>2016-04-14 16:43:28 -0700
commitee4090b60e8b6a350913d1d5049f0770c251cd4a (patch)
tree7e082fa815430c23e0387461be0726cc3e4d04b5
parent2407f5b14edcdcf750113766d82e78732f9852d6 (diff)
parentd7e124edfe2578ecdf8e816a4dda3ce430a09172 (diff)
downloadspark-ee4090b60e8b6a350913d1d5049f0770c251cd4a.tar.gz
spark-ee4090b60e8b6a350913d1d5049f0770c251cd4a.tar.bz2
spark-ee4090b60e8b6a350913d1d5049f0770c251cd4a.zip
Merge remote-tracking branch 'origin/master' into build-for-2.12
-rw-r--r--LICENSE6
-rw-r--r--R/pkg/NAMESPACE1
-rw-r--r--R/pkg/R/functions.R63
-rw-r--r--R/pkg/R/generics.R4
-rw-r--r--R/pkg/R/mllib.R230
-rw-r--r--R/pkg/inst/tests/testthat/test_context.R2
-rw-r--r--R/pkg/inst/tests/testthat/test_mllib.R95
-rw-r--r--R/pkg/inst/tests/testthat/test_rdd.R8
-rw-r--r--R/pkg/inst/tests/testthat/test_sparkSQL.R38
-rw-r--r--assembly/pom.xml101
-rwxr-xr-xbin/spark-class13
-rw-r--r--bin/spark-class2.cmd5
-rwxr-xr-xbuild/mvn10
-rw-r--r--common/network-common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java32
-rw-r--r--common/network-common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java6
-rw-r--r--common/network-common/src/main/java/org/apache/spark/network/protocol/MessageWithHeader.java30
-rw-r--r--common/network-common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java2
-rw-r--r--common/network-common/src/main/java/org/apache/spark/network/util/TransportFrameDecoder.java2
-rw-r--r--common/network-shuffle/src/main/java/org/apache/spark/network/sasl/ShuffleSecretManager.java2
-rw-r--r--common/network-yarn/pom.xml4
-rw-r--r--common/network-yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java7
-rw-r--r--common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java60
-rw-r--r--conf/log4j.properties.template4
-rw-r--r--core/src/main/java/org/apache/spark/JavaSparkListener.java88
-rw-r--r--core/src/main/java/org/apache/spark/SparkExecutorInfo.java33
-rw-r--r--core/src/main/java/org/apache/spark/SparkFirehoseListener.java2
-rw-r--r--core/src/main/java/org/apache/spark/api/java/StorageLevels.java6
-rw-r--r--core/src/main/java/org/apache/spark/io/LZ4BlockInputStream.java (renamed from core/src/main/scala/org/apache/spark/io/LZ4BlockInputStream.java)14
-rw-r--r--core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java6
-rw-r--r--core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java7
-rw-r--r--core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java105
-rw-r--r--core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java7
-rw-r--r--core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java7
-rw-r--r--core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillMerger.java2
-rw-r--r--core/src/main/resources/org/apache/spark/log4j-defaults.properties4
-rw-r--r--core/src/main/resources/org/apache/spark/ui/static/webui.css8
-rw-r--r--core/src/main/scala/org/apache/spark/ContextCleaner.scala10
-rw-r--r--core/src/main/scala/org/apache/spark/ExecutorAllocationClient.scala4
-rw-r--r--core/src/main/scala/org/apache/spark/FutureAction.scala18
-rw-r--r--core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala5
-rw-r--r--core/src/main/scala/org/apache/spark/HttpServer.scala7
-rw-r--r--core/src/main/scala/org/apache/spark/InternalAccumulator.scala2
-rw-r--r--core/src/main/scala/org/apache/spark/SSLOptions.scala57
-rw-r--r--core/src/main/scala/org/apache/spark/SparkConf.scala12
-rw-r--r--core/src/main/scala/org/apache/spark/SparkContext.scala85
-rw-r--r--core/src/main/scala/org/apache/spark/SparkEnv.scala6
-rw-r--r--core/src/main/scala/org/apache/spark/SparkStatusTracker.scala20
-rw-r--r--core/src/main/scala/org/apache/spark/StatusAPIImpl.scala33
-rw-r--r--core/src/main/scala/org/apache/spark/TaskContext.scala9
-rw-r--r--core/src/main/scala/org/apache/spark/TaskContextImpl.scala5
-rw-r--r--core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala8
-rw-r--r--core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala60
-rw-r--r--core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala3
-rw-r--r--core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala80
-rw-r--r--core/src/main/scala/org/apache/spark/api/r/RRDD.scala328
-rw-r--r--core/src/main/scala/org/apache/spark/api/r/RRunner.scala368
-rw-r--r--core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala8
-rw-r--r--core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala29
-rw-r--r--core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala44
-rw-r--r--core/src/main/scala/org/apache/spark/deploy/client/AppClientListener.scala3
-rw-r--r--core/src/main/scala/org/apache/spark/deploy/master/LeaderElectionAgent.scala4
-rw-r--r--core/src/main/scala/org/apache/spark/deploy/master/Master.scala52
-rw-r--r--core/src/main/scala/org/apache/spark/deploy/master/MasterArguments.scala2
-rw-r--r--core/src/main/scala/org/apache/spark/deploy/master/PersistenceEngine.scala4
-rw-r--r--core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala3
-rw-r--r--core/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcherArguments.scala3
-rw-r--r--core/src/main/scala/org/apache/spark/deploy/worker/CommandUtils.scala2
-rw-r--r--core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala2
-rw-r--r--core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala6
-rwxr-xr-xcore/src/main/scala/org/apache/spark/deploy/worker/Worker.scala12
-rw-r--r--core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala3
-rw-r--r--core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala26
-rw-r--r--core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala22
-rw-r--r--core/src/main/scala/org/apache/spark/executor/Executor.scala21
-rw-r--r--core/src/main/scala/org/apache/spark/executor/ExecutorBackend.scala2
-rw-r--r--core/src/main/scala/org/apache/spark/executor/InputMetrics.scala32
-rw-r--r--core/src/main/scala/org/apache/spark/executor/OutputMetrics.scala30
-rw-r--r--core/src/main/scala/org/apache/spark/executor/ShuffleReadMetrics.scala21
-rw-r--r--core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala27
-rw-r--r--core/src/main/scala/org/apache/spark/internal/config/ConfigBuilder.scala44
-rw-r--r--core/src/main/scala/org/apache/spark/internal/config/package.scala50
-rw-r--r--core/src/main/scala/org/apache/spark/launcher/WorkerCommandBuilder.scala1
-rw-r--r--core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala16
-rw-r--r--core/src/main/scala/org/apache/spark/memory/MemoryManager.scala3
-rw-r--r--core/src/main/scala/org/apache/spark/memory/StorageMemoryPool.scala22
-rw-r--r--core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala3
-rw-r--r--core/src/main/scala/org/apache/spark/network/BlockTransferService.scala2
-rw-r--r--core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala12
-rw-r--r--core/src/main/scala/org/apache/spark/partial/BoundedDouble.scala17
-rw-r--r--core/src/main/scala/org/apache/spark/partial/SumEvaluator.scala36
-rw-r--r--core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala4
-rw-r--r--core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala10
-rw-r--r--core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala4
-rw-r--r--core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala16
-rw-r--r--core/src/main/scala/org/apache/spark/rdd/InputFileNameHolder.scala (renamed from core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDDState.scala)6
-rw-r--r--core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala4
-rw-r--r--core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala3
-rw-r--r--core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala8
-rw-r--r--core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala13
-rw-r--r--core/src/main/scala/org/apache/spark/rdd/PartitionerAwareUnionRDD.scala7
-rw-r--r--core/src/main/scala/org/apache/spark/rdd/RDD.scala14
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala11
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/InputFormatInfo.scala6
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/JobListener.scala4
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala5
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/SchedulableBuilder.scala4
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala9
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala251
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala7
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/SplitInfo.scala3
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/Stage.scala19
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala10
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/StatsReportListener.scala199
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/Task.scala39
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala2
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala2
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala7
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala9
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala2
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala70
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala24
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterPersistenceEngine.scala3
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala26
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendUtil.scala6
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala18
-rw-r--r--core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala2
-rw-r--r--core/src/main/scala/org/apache/spark/serializer/Serializer.scala3
-rw-r--r--core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala16
-rw-r--r--core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala4
-rw-r--r--core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala2
-rw-r--r--core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala6
-rw-r--r--core/src/main/scala/org/apache/spark/status/api/v1/AllStagesResource.scala4
-rw-r--r--core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala6
-rw-r--r--core/src/main/scala/org/apache/spark/status/api/v1/VersionResource.scala30
-rw-r--r--core/src/main/scala/org/apache/spark/status/api/v1/api.scala9
-rw-r--r--core/src/main/scala/org/apache/spark/storage/BlockManager.scala85
-rw-r--r--core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala2
-rw-r--r--core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala6
-rw-r--r--core/src/main/scala/org/apache/spark/storage/StorageLevel.scala21
-rw-r--r--core/src/main/scala/org/apache/spark/storage/StorageStatusListener.scala21
-rw-r--r--core/src/main/scala/org/apache/spark/storage/StorageUtils.scala5
-rw-r--r--core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala171
-rw-r--r--core/src/main/scala/org/apache/spark/ui/JettyUtils.scala14
-rw-r--r--core/src/main/scala/org/apache/spark/ui/SparkUI.scala4
-rw-r--r--core/src/main/scala/org/apache/spark/ui/WebUI.scala4
-rw-r--r--core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala3
-rw-r--r--core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala2
-rw-r--r--core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala6
-rw-r--r--core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala4
-rw-r--r--core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala4
-rw-r--r--core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala8
-rw-r--r--core/src/main/scala/org/apache/spark/ui/jobs/JobsTab.scala2
-rw-r--r--core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala85
-rw-r--r--core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala6
-rw-r--r--core/src/main/scala/org/apache/spark/ui/storage/StorageTab.scala21
-rw-r--r--core/src/main/scala/org/apache/spark/util/CausedBy.scala (renamed from sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ASTNodeSuite.scala)34
-rw-r--r--core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala38
-rw-r--r--core/src/main/scala/org/apache/spark/util/EventLoop.scala3
-rw-r--r--core/src/main/scala/org/apache/spark/util/JsonProtocol.scala10
-rw-r--r--core/src/main/scala/org/apache/spark/util/SizeEstimator.scala3
-rw-r--r--core/src/main/scala/org/apache/spark/util/UninterruptibleThread.scala112
-rw-r--r--core/src/main/scala/org/apache/spark/util/Utils.scala50
-rw-r--r--core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala14
-rw-r--r--core/src/main/scala/org/apache/spark/util/io/ChunkedByteBufferOutputStream.scala (renamed from core/src/main/scala/org/apache/spark/util/io/ByteArrayChunkOutputStream.scala)38
-rw-r--r--core/src/main/scala/org/apache/spark/util/logging/RollingPolicy.scala4
-rw-r--r--core/src/main/scala/org/apache/spark/util/random/Pseudorandom.scala2
-rw-r--r--core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala203
-rw-r--r--core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java4
-rw-r--r--core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java64
-rw-r--r--core/src/test/resources/HistoryServerExpectations/application_list_json_expectation.json24
-rw-r--r--core/src/test/resources/HistoryServerExpectations/completed_app_list_json_expectation.json24
-rw-r--r--core/src/test/resources/HistoryServerExpectations/executor_list_json_expectation.json4
-rw-r--r--core/src/test/resources/HistoryServerExpectations/maxDate2_app_list_json_expectation.json3
-rw-r--r--core/src/test/resources/HistoryServerExpectations/maxDate_app_list_json_expectation.json6
-rw-r--r--core/src/test/resources/HistoryServerExpectations/minDate_app_list_json_expectation.json18
-rw-r--r--core/src/test/resources/HistoryServerExpectations/one_app_json_expectation.json3
-rw-r--r--core/src/test/resources/HistoryServerExpectations/one_app_multi_attempt_json_expectation.json6
-rw-r--r--core/src/test/resources/HistoryServerExpectations/rdd_list_storage_json_expectation.json10
-rw-r--r--core/src/test/resources/log4j.properties3
-rw-r--r--core/src/test/scala/org/apache/spark/AccumulatorSuite.scala5
-rw-r--r--core/src/test/scala/org/apache/spark/DistributedSuite.scala6
-rw-r--r--core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala4
-rw-r--r--core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala11
-rw-r--r--core/src/test/scala/org/apache/spark/ShuffleSuite.scala13
-rw-r--r--core/src/test/scala/org/apache/spark/Smuggle.scala46
-rw-r--r--core/src/test/scala/org/apache/spark/SparkContextInfoSuite.scala6
-rw-r--r--core/src/test/scala/org/apache/spark/UnpersistSuite.scala2
-rw-r--r--core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala19
-rw-r--r--core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala2
-rw-r--r--core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala9
-rw-r--r--core/src/test/scala/org/apache/spark/executor/TaskMetricsSuite.scala4
-rw-r--r--core/src/test/scala/org/apache/spark/internal/config/ConfigEntrySuite.scala28
-rw-r--r--core/src/test/scala/org/apache/spark/io/ChunkedByteBufferSuite.scala2
-rw-r--r--core/src/test/scala/org/apache/spark/memory/MemoryManagerSuite.scala27
-rw-r--r--core/src/test/scala/org/apache/spark/memory/MemoryTestingUtils.scala3
-rw-r--r--core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala4
-rw-r--r--core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferServiceSuite.scala2
-rw-r--r--core/src/test/scala/org/apache/spark/partial/SumEvaluatorSuite.scala107
-rw-r--r--core/src/test/scala/org/apache/spark/rdd/PartitionwiseSampledRDDSuite.scala2
-rw-r--r--core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala15
-rw-r--r--core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala2
-rw-r--r--core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala4
-rw-r--r--core/src/test/scala/org/apache/spark/scheduler/NotSerializableFakeTask.scala3
-rw-r--r--core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorIntegrationSuite.scala2
-rw-r--r--core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala2
-rw-r--r--core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala19
-rw-r--r--core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala24
-rw-r--r--core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala4
-rw-r--r--core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterSchedulerSuite.scala36
-rw-r--r--core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtilsSuite.scala32
-rw-r--r--core/src/test/scala/org/apache/spark/storage/BlockInfoManagerSuite.scala5
-rw-r--r--core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala26
-rw-r--r--core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala61
-rw-r--r--core/src/test/scala/org/apache/spark/storage/MemoryStoreSuite.scala22
-rw-r--r--core/src/test/scala/org/apache/spark/storage/StorageStatusListenerSuite.scala67
-rw-r--r--core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala10
-rw-r--r--core/src/test/scala/org/apache/spark/ui/storage/StorageTabSuite.scala58
-rw-r--r--core/src/test/scala/org/apache/spark/util/CausedBySuite.scala56
-rw-r--r--core/src/test/scala/org/apache/spark/util/FileAppenderSuite.scala72
-rw-r--r--core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala2
-rw-r--r--core/src/test/scala/org/apache/spark/util/UninterruptibleThreadSuite.scala159
-rw-r--r--core/src/test/scala/org/apache/spark/util/io/ChunkedByteBufferOutputStreamSuite.scala (renamed from core/src/test/scala/org/apache/spark/util/io/ByteArrayChunkOutputStreamSuite.scala)47
-rw-r--r--core/src/test/scala/org/apache/spark/util/random/RandomSamplerSuite.scala197
-rw-r--r--dev/deps/spark-deps-hadoop-2.232
-rw-r--r--dev/deps/spark-deps-hadoop-2.332
-rw-r--r--dev/deps/spark-deps-hadoop-2.432
-rw-r--r--dev/deps/spark-deps-hadoop-2.632
-rw-r--r--dev/deps/spark-deps-hadoop-2.732
-rwxr-xr-xdev/make-distribution.sh25
-rwxr-xr-xdev/mima6
-rwxr-xr-xdev/run-tests.py11
-rw-r--r--dev/sparktestsupport/modules.py14
-rw-r--r--docs/building-spark.md21
-rw-r--r--docs/configuration.md11
-rw-r--r--docs/ml-classification-regression.md34
-rw-r--r--docs/ml-features.md33
-rw-r--r--docs/monitoring.md58
-rw-r--r--docs/running-on-mesos.md6
-rw-r--r--docs/running-on-yarn.md18
-rw-r--r--docs/sql-programming-guide.md48
-rw-r--r--docs/streaming-programming-guide.md2
-rw-r--r--docs/submitting-applications.md3
-rw-r--r--examples/pom.xml80
-rw-r--r--examples/src/main/java/org/apache/spark/examples/JavaLogQuery.java4
-rw-r--r--examples/src/main/java/org/apache/spark/examples/ml/JavaCrossValidatorExample.java127
-rw-r--r--examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java2
-rw-r--r--examples/src/main/java/org/apache/spark/examples/ml/JavaModelSelectionViaTrainValidationSplitExample.java10
-rw-r--r--examples/src/main/java/org/apache/spark/examples/ml/JavaNaiveBayesExample.java64
-rw-r--r--examples/src/main/java/org/apache/spark/examples/ml/JavaTrainValidationSplitExample.java87
-rw-r--r--examples/src/main/java/org/apache/spark/examples/mllib/JavaFPGrowthExample.java78
-rw-r--r--examples/src/main/java/org/apache/spark/examples/mllib/JavaKMeans.java82
-rw-r--r--examples/src/main/java/org/apache/spark/examples/mllib/JavaKMeansExample.java7
-rw-r--r--examples/src/main/java/org/apache/spark/examples/mllib/JavaLDAExample.java77
-rw-r--r--examples/src/main/java/org/apache/spark/examples/mllib/JavaLR.java82
-rw-r--r--examples/src/main/java/org/apache/spark/examples/mllib/JavaMultiLabelClassificationMetricsExample.java14
-rw-r--r--examples/src/main/java/org/apache/spark/examples/mllib/JavaPowerIterationClusteringExample.java10
-rw-r--r--examples/src/main/java/org/apache/spark/examples/mllib/JavaStratifiedSamplingExample.java3
-rw-r--r--examples/src/main/java/org/apache/spark/examples/streaming/JavaFlumeEventCount.java4
-rw-r--r--examples/src/main/python/ml/count_vectorizer_example.py44
-rw-r--r--examples/src/main/python/ml/dct_example.py45
-rw-r--r--examples/src/main/python/ml/max_abs_scaler_example.py43
-rw-r--r--examples/src/main/python/ml/min_max_scaler_example.py43
-rw-r--r--examples/src/main/python/ml/naive_bayes_example.py53
-rw-r--r--examples/src/main/scala/org/apache/spark/examples/BroadcastTest.scala4
-rw-r--r--examples/src/main/scala/org/apache/spark/examples/CassandraCQLTest.scala6
-rw-r--r--examples/src/main/scala/org/apache/spark/examples/CassandraTest.scala6
-rw-r--r--examples/src/main/scala/org/apache/spark/examples/DFSReadWriteTest.scala20
-rw-r--r--examples/src/main/scala/org/apache/spark/examples/DriverSubmissionTest.scala6
-rw-r--r--examples/src/main/scala/org/apache/spark/examples/GroupByTest.scala4
-rw-r--r--examples/src/main/scala/org/apache/spark/examples/LocalALS.scala6
-rw-r--r--examples/src/main/scala/org/apache/spark/examples/MultiBroadcastTest.scala4
-rw-r--r--examples/src/main/scala/org/apache/spark/examples/SimpleSkewedGroupByTest.scala4
-rw-r--r--examples/src/main/scala/org/apache/spark/examples/SkewedGroupByTest.scala4
-rw-r--r--examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala4
-rw-r--r--examples/src/main/scala/org/apache/spark/examples/ml/KMeansExample.scala2
-rw-r--r--examples/src/main/scala/org/apache/spark/examples/ml/NaiveBayesExample.scala58
-rw-r--r--examples/src/main/scala/org/apache/spark/examples/ml/OneVsRestExample.scala6
-rw-r--r--examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala6
-rw-r--r--examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala6
-rw-r--r--examples/src/main/scala/org/apache/spark/examples/streaming/SqlNetworkWordCount.scala4
-rw-r--r--examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewGenerator.scala23
-rw-r--r--examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewStream.scala21
-rw-r--r--external/docker-integration-tests/pom.xml30
-rw-r--r--external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/DB2IntegrationSuite.scala157
-rw-r--r--external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/DockerJDBCIntegrationSuite.scala20
-rw-r--r--external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MySQLIntegrationSuite.scala2
-rw-r--r--external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/OracleIntegrationSuite.scala2
-rw-r--r--external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala2
-rw-r--r--external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkAvroCallbackHandler.scala4
-rw-r--r--external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkSink.scala12
-rw-r--r--external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/TransactionProcessor.scala12
-rw-r--r--external/flume-sink/src/test/resources/log4j.properties2
-rw-r--r--external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala21
-rw-r--r--external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumePollingInputDStream.scala4
-rw-r--r--external/flume/src/main/scala/org/apache/spark/streaming/flume/PollingFlumeTestUtils.scala4
-rw-r--r--external/flume/src/test/java/org/apache/spark/streaming/flume/JavaFlumeStreamSuite.java11
-rw-r--r--external/flume/src/test/resources/log4j.properties2
-rw-r--r--external/java8-tests/README.md8
-rw-r--r--external/java8-tests/pom.xml87
-rw-r--r--external/java8-tests/src/test/java/org/apache/spark/Java8APISuite.java10
-rw-r--r--external/java8-tests/src/test/java/org/apache/spark/sql/Java8DatasetAggregatorSuite.java61
-rw-r--r--external/java8-tests/src/test/java/org/apache/spark/streaming/Java8APISuite.java22
-rw-r--r--external/java8-tests/src/test/resources/log4j.properties3
-rw-r--r--external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDDPartition.scala15
-rw-r--r--external/kafka/src/test/resources/log4j.properties2
-rw-r--r--external/kinesis-asl/src/main/resources/log4j.properties4
-rw-r--r--external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala25
-rw-r--r--external/kinesis-asl/src/test/resources/log4j.properties2
-rw-r--r--graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala12
-rw-r--r--graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala2
-rw-r--r--graphx/src/main/scala/org/apache/spark/graphx/lib/ConnectedComponents.scala18
-rw-r--r--graphx/src/test/resources/log4j.properties3
-rw-r--r--graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala9
-rw-r--r--launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java71
-rw-r--r--launcher/src/main/java/org/apache/spark/launcher/CommandBuilderUtils.java27
-rw-r--r--launcher/src/main/java/org/apache/spark/launcher/SparkClassCommandBuilder.java13
-rw-r--r--launcher/src/main/java/org/apache/spark/launcher/SparkLauncher.java2
-rw-r--r--launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java30
-rw-r--r--launcher/src/test/java/org/apache/spark/launcher/LauncherServerSuite.java8
-rw-r--r--launcher/src/test/java/org/apache/spark/launcher/SparkSubmitCommandBuilderSuite.java6
-rw-r--r--launcher/src/test/resources/log4j.properties3
-rw-r--r--mllib-local/pom.xml74
-rw-r--r--mllib-local/src/main/scala/org/apache/spark/ml/DummyTesting.scala (renamed from sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserConf.scala)13
-rw-r--r--mllib-local/src/test/scala/org/apache/spark/ml/DummyTestingSuite.scala28
-rw-r--r--mllib/pom.xml25
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/Estimator.scala16
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala14
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/Predictor.scala23
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/Transformer.scala15
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/ann/Layer.scala662
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/ann/LossFunction.scala124
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala6
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala8
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala149
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala54
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala89
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala4
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala179
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala6
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala124
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala67
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala311
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala50
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala135
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala6
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/evaluation/Evaluator.scala10
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala6
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala15
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala3
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala3
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala6
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala57
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/feature/ElementwiseProduct.scala6
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala28
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala6
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala6
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala6
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala6
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala5
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala6
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala124
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala18
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala16
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala6
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala5
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala13
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala7
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala8
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala5
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala8
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/param/params.scala11
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala2
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala4
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/r/AFTSurvivalRegressionWrapper.scala6
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala79
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/r/KMeansWrapper.scala85
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala8
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala167
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala10
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala128
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala15
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala155
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala44
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala16
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala77
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala121
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala110
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/tree/impl/BaggedPoint.scala (renamed from mllib/src/main/scala/org/apache/spark/mllib/tree/impl/BaggedPoint.scala)2
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/tree/impl/DTStatsAggregator.scala (renamed from mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala)5
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeMetadata.scala (renamed from mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala)7
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala100
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/tree/impl/NodeIdCache.scala1
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala114
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/tree/impl/TimeTracker.scala (renamed from mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TimeTracker.scala)2
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/tree/impl/TreePoint.scala1
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala266
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala147
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala169
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala135
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala117
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala117
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala82
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala24
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/api/python/GaussianMixtureModelWrapper.scala8
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/api/python/Word2VecModelWrapper.scala4
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala2
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeans.scala2
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeansModel.scala98
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala3
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala3
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala8
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala10
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala13
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala39
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala10
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/feature/HashingTF.scala15
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala96
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicCheckpointer.scala47
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointer.scala3
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala28
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala2
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala4
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala10
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/stat/test/KolmogorovSmirnovTest.scala3
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala181
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala11
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/tree/impl/NodeIdCache.scala195
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala150
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala2
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala2
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala2
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala47
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala2
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala73
-rw-r--r--mllib/src/test/java/org/apache/spark/ml/classification/JavaMultilayerPerceptronClassifierSuite.java2
-rw-r--r--mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java19
-rw-r--r--mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java2
-rw-r--r--mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java19
-rw-r--r--mllib/src/test/java/org/apache/spark/mllib/classification/JavaStreamingLogisticRegressionSuite.java4
-rw-r--r--mllib/src/test/java/org/apache/spark/mllib/clustering/JavaStreamingKMeansSuite.java4
-rw-r--r--mllib/src/test/java/org/apache/spark/mllib/fpm/JavaPrefixSpanSuite.java37
-rw-r--r--mllib/src/test/java/org/apache/spark/mllib/linalg/JavaVectorsSuite.java4
-rw-r--r--mllib/src/test/java/org/apache/spark/mllib/regression/JavaStreamingLinearRegressionSuite.java4
-rw-r--r--mllib/src/test/resources/log4j.properties2
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala12
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/ann/ANNSuite.scala9
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/ann/GradientSuite.scala76
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala15
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala71
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala15
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala59
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala18
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala90
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala48
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala26
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala133
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala22
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala43
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala26
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala24
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/feature/NGramSuite.scala4
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala115
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/feature/SQLTransformerSuite.scala10
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala4
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala2
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala4
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala31
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala8
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala69
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala24
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala9
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala19
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala46
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala15
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/tree/impl/BaggedPointSuite.scala (renamed from mllib/src/test/scala/org/apache/spark/mllib/tree/impl/BaggedPointSuite.scala)2
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/tree/impl/GradientBoostedTreesSuite.scala85
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala46
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala20
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala8
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala55
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala4
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala86
-rw-r--r--mllib/src/test/scala/org/apache/spark/mllib/clustering/BisectingKMeansSuite.scala18
-rw-r--r--mllib/src/test/scala/org/apache/spark/mllib/feature/HashingTFSuite.scala12
-rw-r--r--mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala31
-rw-r--r--mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala16
-rw-r--r--mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala2
-rw-r--r--mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala45
-rw-r--r--mllib/src/test/scala/org/apache/spark/mllib/util/MLlibTestSparkContext.scala13
-rw-r--r--pom.xml159
-rw-r--r--project/MimaExcludes.scala49
-rw-r--r--project/SparkBuild.scala127
-rw-r--r--project/plugins.sbt7
-rw-r--r--python/docs/Makefile8
-rw-r--r--python/pyspark/broadcast.py17
-rw-r--r--python/pyspark/context.py22
-rw-r--r--python/pyspark/join.py2
-rw-r--r--python/pyspark/ml/classification.py284
-rw-r--r--python/pyspark/ml/clustering.py17
-rw-r--r--python/pyspark/ml/evaluation.py4
-rw-r--r--python/pyspark/ml/feature.py58
-rw-r--r--python/pyspark/ml/param/_shared_params_code_gen.py4
-rw-r--r--python/pyspark/ml/param/shared.py24
-rw-r--r--python/pyspark/ml/pipeline.py10
-rw-r--r--python/pyspark/ml/regression.py461
-rw-r--r--python/pyspark/ml/tests.py185
-rw-r--r--python/pyspark/ml/tuning.py409
-rw-r--r--python/pyspark/ml/util.py8
-rw-r--r--python/pyspark/ml/wrapper.py85
-rw-r--r--python/pyspark/mllib/feature.py13
-rw-r--r--python/pyspark/mllib/tests.py16
-rw-r--r--python/pyspark/rdd.py10
-rw-r--r--python/pyspark/sql/context.py2
-rw-r--r--python/pyspark/sql/dataframe.py16
-rw-r--r--python/pyspark/sql/functions.py64
-rw-r--r--python/pyspark/sql/readwriter.py4
-rw-r--r--python/pyspark/sql/tests.py34
-rw-r--r--python/pyspark/sql/utils.py8
-rw-r--r--python/pyspark/storagelevel.py4
-rw-r--r--python/pyspark/streaming/tests.py6
-rw-r--r--python/pyspark/tests.py34
-rw-r--r--python/pyspark/worker.py71
-rwxr-xr-xpython/run-tests.py18
-rw-r--r--repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoop.scala21
-rw-r--r--repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkImports.scala5
-rw-r--r--repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala28
-rw-r--r--repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoop.scala17
-rw-r--r--repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala56
-rw-r--r--repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala36
-rw-r--r--repl/src/test/resources/log4j.properties2
-rw-r--r--scalastyle-config.xml12
-rw-r--r--sql/catalyst/pom.xml13
-rw-r--r--sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/ExpressionParser.g400
-rw-r--r--sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/FromClauseParser.g341
-rw-r--r--sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/IdentifiersParser.g184
-rw-r--r--sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/KeywordParser.g244
-rw-r--r--sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SelectClauseParser.g235
-rw-r--r--sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlLexer.g491
-rw-r--r--sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlParser.g2596
-rw-r--r--sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4957
-rw-r--r--sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/parser/ParseUtils.java135
-rw-r--r--sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java2
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala231
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala314
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala2
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala32
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala58
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala494
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala4
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala77
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/NoSuchItemException.scala31
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala53
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala50
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala321
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/functionResources.scala61
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala62
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala53
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala91
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala34
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/errors/package.scala7
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala5
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala4
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala4
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala8
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala6
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala2
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala168
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala2
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala14
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala2
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala6
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala4
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala5
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlus.scala7
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala2
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Max.scala2
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Min.scala3
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala2
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala37
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala120
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala12
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatter.scala38
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala42
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala10
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala10
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala15
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala82
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala4
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/grouping.scala18
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala6
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala129
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala6
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala8
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala11
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala86
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala7
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala124
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala4
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala14
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala2
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala117
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala6
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/identifiers.scala16
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala286
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ASTNode.scala99
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AbstractSparkSQLParser.scala145
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala1455
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/CatalystQl.scala933
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/DataTypeParser.scala67
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala245
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala281
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala104
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala50
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala17
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala4
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala25
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala111
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala6
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala27
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringKeyHashMap.scala2
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringUtils.scala23
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala9
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala6
-rw-r--r--sql/catalyst/src/test/resources/log4j.properties3
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala18
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/RowTest.scala2
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala4
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala56
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala6
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala6
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/CatalogTestCases.scala61
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala275
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala4
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala1
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/TimeWindowSuite.scala111
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatterSuite.scala55
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BinaryComparisonSimplificationSuite.scala95
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala3
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala2
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSortsSuite.scala4
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala76
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala4
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LikeSimplificationSuite.scala14
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizerExtendableSuite.scala14
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PruneFiltersSuite.scala2
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala16
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/TypedFilterOptimizationSuite.scala74
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/CatalystQlSuite.scala243
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DataTypeParserSuite.scala54
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ErrorParserSuite.scala67
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala497
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ParserUtilsSuite.scala65
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala431
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableIdentifierParserSuite.scala42
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala126
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala23
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/SameResultSuite.scala7
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/StringUtilsSuite.scala12
-rw-r--r--sql/core/pom.xml2
-rw-r--r--sql/core/src/main/java/org/apache/spark/sql/execution/BufferedRowIterator.java (renamed from sql/core/src/main/scala/org/apache/spark/sql/execution/BufferedRowIterator.java)4
-rw-r--r--sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java2
-rw-r--r--sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java2
-rw-r--r--sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java120
-rw-r--r--sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java25
-rw-r--r--sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/AggregateHashMap.java110
-rw-r--r--sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java26
-rw-r--r--sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java98
-rw-r--r--sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java58
-rw-r--r--sql/core/src/main/java/org/apache/spark/sql/expressions/java/typed.java75
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/ContinuousQuery.scala2
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/ContinuousQueryManager.scala24
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala18
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala6
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala109
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala185
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/ExperimentalMethods.scala2
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala13
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala4
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala56
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala18
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/Trigger.scala133
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala22
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala258
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala41
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/Queryable.scala124
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala46
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala6
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/SparkQl.scala329
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala792
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala119
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala228
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala62
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala22
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ColumnarAggMapCodeGenerator.scala193
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala6
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala39
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala4
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/typedaggregators.scala101
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala121
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala155
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnAccessor.scala2
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnBuilder.scala6
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala46
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/command/AlterTableCommandParser.scala431
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala176
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala504
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/command/functions.scala105
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala118
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala77
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala404
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala11
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala44
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFileLinesReader.scala51
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelation.scala1
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SqlNewHadoopRDD.scala314
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala91
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala22
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala3
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVParser.scala81
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala47
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/DefaultSource.scala73
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala12
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala4
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/InferSchema.scala25
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONOptions.scala4
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala55
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonParser.scala8
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/DirectParquetOutputCommitter.scala88
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala59
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala216
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala65
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala53
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala215
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala57
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala8
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala282
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala944
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala74
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala36
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala136
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchPythonEvaluation.scala87
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala29
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala119
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala3
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala2
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala2
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala72
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulAggregate.scala119
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala169
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala27
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TriggerExecutor.scala72
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala41
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala80
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala28
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala7
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala11
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala23
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala21
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala11
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/ui/AllExecutionsPage.scala40
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala8
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala43
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala8
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala8
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/expressions/scala/typed.scala89
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala4
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/functions.scala346
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala795
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala36
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala12
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala84
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/util/ContinuousQueryListener.scala18
-rw-r--r--sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java84
-rw-r--r--sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaDatasetAggregatorSuite.java134
-rw-r--r--sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaDatasetAggregatorSuiteBase.java81
-rw-r--r--sql/core/src/test/resources/unescaped-quotes.csv2
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala50
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala9
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/DataFrameTimeWindowingSuite.scala299
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala152
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/DatasetBenchmark.scala146
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala25
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala40
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/ProcessingTimeSuite.scala40
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala29
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala108
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala49
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala6
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala3
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala247
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala8
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala10
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala3
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala3
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala3
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala29
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala10
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala608
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala719
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala110
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala2
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVParserSuite.scala14
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala12
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala85
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/TestJsonData.scala8
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala49
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala10
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala13
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/execution/joins/ExistenceJoinSuite.scala (renamed from sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala)74
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala77
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLogSuite.scala1
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ProcessingTimeExecutorSuite.scala78
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala160
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala63
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala29
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfEntrySuite.scala29
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala7
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/streaming/ContinuousQueryManagerSuite.scala13
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/streaming/DataFrameReaderWriterSuite.scala104
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala4
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala40
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStressSuite.scala10
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/streaming/MemorySinkSuite.scala82
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala71
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala130
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala22
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/util/ContinuousQueryListenerSuite.scala3
-rw-r--r--sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala2
-rw-r--r--sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala20
-rw-r--r--sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala98
-rw-r--r--sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala195
-rw-r--r--sql/hive/pom.xml43
-rw-r--r--sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala104
-rw-r--r--sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala (renamed from sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveCatalog.scala)56
-rw-r--r--sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala27
-rw-r--r--sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala310
-rw-r--r--sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala751
-rw-r--r--sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala162
-rw-r--r--sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala30
-rw-r--r--sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala4
-rw-r--r--sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala10
-rw-r--r--sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala4
-rw-r--r--sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClient.scala16
-rw-r--r--sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala106
-rw-r--r--sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateTableAsSelect.scala6
-rw-r--r--sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateViewAsSelect.scala4
-rw-r--r--sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveSqlParser.scala503
-rw-r--r--sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala31
-rw-r--r--sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala107
-rw-r--r--sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala30
-rw-r--r--sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala24
-rw-r--r--sql/hive/src/test/scala/org/apache/spark/sql/hive/ClasspathDependenciesSuite.scala41
-rw-r--r--sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala4
-rw-r--r--sql/hive/src/test/scala/org/apache/spark/sql/hive/ExpressionToSQLSuite.scala4
-rw-r--r--sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDDLCommandSuite.scala582
-rw-r--r--sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogSuite.scala (renamed from sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveCatalogSuite.scala)6
-rw-r--r--sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala12
-rw-r--r--sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveQlSuite.scala231
-rw-r--r--sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala209
-rw-r--r--sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala2
-rw-r--r--sql/hive/src/test/scala/org/apache/spark/sql/hive/ListTablesSuite.scala2
-rw-r--r--sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala2
-rw-r--r--sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala5
-rw-r--r--sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala167
-rw-r--r--sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala2
-rw-r--r--sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveCommandSuite.scala125
-rw-r--r--sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala6
-rw-r--r--sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala351
-rw-r--r--sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala29
-rw-r--r--sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala67
-rw-r--r--sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala10
-rw-r--r--sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala2
-rw-r--r--sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala134
-rw-r--r--sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala39
-rw-r--r--sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala4
-rw-r--r--sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala34
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala11
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala46
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala7
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/dstream/ConstantInputDStream.scala2
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala11
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/dstream/DStreamCheckpointData.scala9
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala3
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala10
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/dstream/ReducedWindowedDStream.scala2
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/dstream/StateDStream.scala40
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/dstream/UnionDStream.scala4
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/dstream/WindowedDStream.scala11
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala3
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala8
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/receiver/RateLimiter.scala25
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala2
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/receiver/Receiver.scala4
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisor.scala10
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/scheduler/ExecutorAllocationManager.scala233
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala5
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala23
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobSet.scala7
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala4
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala29
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/scheduler/rate/RateEstimator.scala22
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/ui/AllBatchesTable.scala4
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingJobProgressListener.scala12
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala65
-rw-r--r--streaming/src/test/resources/log4j.properties2
-rw-r--r--streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala3
-rw-r--r--streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala11
-rw-r--r--streaming/src/test/scala/org/apache/spark/streaming/MasterFailureTest.scala3
-rw-r--r--streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala5
-rw-r--r--streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala10
-rw-r--r--streaming/src/test/scala/org/apache/spark/streaming/UISeleniumSuite.scala4
-rw-r--r--streaming/src/test/scala/org/apache/spark/streaming/scheduler/ExecutorAllocationManagerSuite.scala395
-rw-r--r--streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala23
-rw-r--r--tools/src/main/scala/org/apache/spark/tools/GenerateMIMAIgnore.scala9
-rw-r--r--yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala5
-rw-r--r--yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala18
-rw-r--r--yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala200
-rw-r--r--yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala189
-rw-r--r--yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala4
-rw-r--r--yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala29
-rw-r--r--yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala5
-rw-r--r--yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala18
-rw-r--r--yarn/src/main/scala/org/apache/spark/deploy/yarn/config.scala105
-rw-r--r--yarn/src/main/scala/org/apache/spark/launcher/YarnCommandBuilderUtils.scala9
-rw-r--r--yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala42
-rw-r--r--yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala18
-rw-r--r--yarn/src/test/resources/log4j.properties2
-rw-r--r--yarn/src/test/scala/org/apache/spark/deploy/yarn/BaseYarnClusterSuite.scala2
-rw-r--r--yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala31
-rw-r--r--yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala8
-rw-r--r--yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala46
-rw-r--r--yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala24
932 files changed, 33469 insertions, 21788 deletions
diff --git a/LICENSE b/LICENSE
index d7a790a628..9714b3b1e4 100644
--- a/LICENSE
+++ b/LICENSE
@@ -238,6 +238,7 @@ The text of each license is also included at licenses/LICENSE-[project].txt.
(BSD 3 Clause) netlib core (com.github.fommil.netlib:core:1.1.2 - https://github.com/fommil/netlib-java/core)
(BSD 3 Clause) JPMML-Model (org.jpmml:pmml-model:1.2.7 - https://github.com/jpmml/jpmml-model)
(BSD License) AntLR Parser Generator (antlr:antlr:2.7.7 - http://www.antlr.org/)
+ (BSD License) ANTLR 4.5.2-1 (org.antlr:antlr4:4.5.2-1 - http://www.antlr.org/)
(BSD licence) ANTLR ST4 4.0.4 (org.antlr:ST4:4.0.4 - http://www.stringtemplate.org)
(BSD licence) ANTLR StringTemplate (org.antlr:stringtemplate:3.2.1 - http://www.stringtemplate.org)
(BSD License) Javolution (javolution:javolution:5.5.1 - http://javolution.org)
@@ -256,9 +257,8 @@ The text of each license is also included at licenses/LICENSE-[project].txt.
(BSD-style) scalacheck (org.scalacheck:scalacheck_2.11:1.10.0 - http://www.scalacheck.org)
(BSD-style) spire (org.spire-math:spire_2.11:0.7.1 - http://spire-math.org)
(BSD-style) spire-macros (org.spire-math:spire-macros_2.11:0.7.1 - http://spire-math.org)
- (New BSD License) Kryo (com.esotericsoftware.kryo:kryo:2.21 - http://code.google.com/p/kryo/)
- (New BSD License) MinLog (com.esotericsoftware.minlog:minlog:1.2 - http://code.google.com/p/minlog/)
- (New BSD License) ReflectASM (com.esotericsoftware.reflectasm:reflectasm:1.07 - http://code.google.com/p/reflectasm/)
+ (New BSD License) Kryo (com.esotericsoftware:kryo:3.0.3 - https://github.com/EsotericSoftware/kryo)
+ (New BSD License) MinLog (com.esotericsoftware:minlog:1.3.0 - https://github.com/EsotericSoftware/minlog)
(New BSD license) Protocol Buffer Java API (com.google.protobuf:protobuf-java:2.5.0 - http://code.google.com/p/protobuf)
(New BSD license) Protocol Buffer Java API (org.spark-project.protobuf:protobuf-java:2.4.1-shaded - http://code.google.com/p/protobuf)
(The BSD License) Fortran to Java ARPACK (net.sourceforge.f2j:arpack_combined_all:0.1 - http://f2j.sourceforge.net)
diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE
index fa3fb0b09a..f48c61c1d5 100644
--- a/R/pkg/NAMESPACE
+++ b/R/pkg/NAMESPACE
@@ -265,6 +265,7 @@ exportMethods("%in%",
"var_samp",
"weekofyear",
"when",
+ "window",
"year")
exportClasses("GroupedData")
diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R
index d9c10b4a4b..db877b2d63 100644
--- a/R/pkg/R/functions.R
+++ b/R/pkg/R/functions.R
@@ -2131,6 +2131,69 @@ setMethod("from_unixtime", signature(x = "Column"),
column(jc)
})
+#' window
+#'
+#' Bucketize rows into one or more time windows given a column that specifies timestamps. Window
+#' starts are inclusive but the window ends are exclusive, e.g. 12:05 will be in the window
+#' [12:05,12:10) but not in [12:00,12:05). Windows can support microsecond precision. Windows in
+#' the order of months are not supported.
+#'
+#' The time column must be of TimestampType.
+#'
+#' Durations are provided as strings, e.g. '1 second', '1 day 12 hours', '2 minutes'. Valid
+#' interval strings are 'week', 'day', 'hour', 'minute', 'second', 'millisecond', 'microsecond'.
+#' If the `slideDuration` is not provided, the windows will be tumbling windows.
+#'
+#' The startTime is the offset with respect to 1970-01-01 00:00:00 UTC with which to start
+#' window intervals. For example, in order to have hourly tumbling windows that start 15 minutes
+#' past the hour, e.g. 12:15-13:15, 13:15-14:15... provide `startTime` as `15 minutes`.
+#'
+#' The output column will be a struct called 'window' by default with the nested columns 'start'
+#' and 'end'.
+#'
+#' @family datetime_funcs
+#' @rdname window
+#' @name window
+#' @export
+#' @examples
+#'\dontrun{
+#' # One minute windows every 15 seconds, starting 10 seconds after the minute, e.g. 09:00:10-09:01:10,
+#' # 09:00:25-09:01:25, 09:00:40-09:01:40, ...
+#' window(df$time, "1 minute", "15 seconds", "10 seconds")
+#'
+#' # One minute tumbling windows, starting 15 seconds after the minute, e.g. 09:00:15-09:01:15,
+#' # 09:01:15-09:02:15...
+#' window(df$time, "1 minute", startTime = "15 seconds")
+#'
+#' # Thirty second windows every 10 seconds, e.g. 09:00:00-09:00:30, 09:00:10-09:00:40, ...
+#' window(df$time, "30 seconds", "10 seconds")
+#'}
+setMethod("window", signature(x = "Column"),
+ function(x, windowDuration, slideDuration = NULL, startTime = NULL) {
+ stopifnot(is.character(windowDuration))
+ if (!is.null(slideDuration) && !is.null(startTime)) {
+ stopifnot(is.character(slideDuration) && is.character(startTime))
+ jc <- callJStatic("org.apache.spark.sql.functions",
+ "window",
+ x@jc, windowDuration, slideDuration, startTime)
+ } else if (!is.null(slideDuration)) {
+ stopifnot(is.character(slideDuration))
+ jc <- callJStatic("org.apache.spark.sql.functions",
+ "window",
+ x@jc, windowDuration, slideDuration)
+ } else if (!is.null(startTime)) {
+ stopifnot(is.character(startTime))
+ jc <- callJStatic("org.apache.spark.sql.functions",
+ "window",
+ x@jc, windowDuration, windowDuration, startTime)
+ } else {
+ jc <- callJStatic("org.apache.spark.sql.functions",
+ "window",
+ x@jc, windowDuration)
+ }
+ column(jc)
+ })
+
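As a quick illustration of how the new window() column is typically consumed in an aggregation (a sketch, not part of the patch; it assumes a DataFrame df with a TimestampType column time and a numeric column value, both illustrative names):

  # Sliding one-minute windows evaluated every 15 seconds, aggregated per window (illustrative names).
  windowed <- groupBy(df, window(df$time, "1 minute", "15 seconds"))
  totals <- agg(windowed, total = sum(df$value))
  head(totals)  # each row carries a struct column 'window' with nested 'start' and 'end'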
#' locate
#'
#' Locate the position of the first occurrence of substr.
diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R
index c6990f4748..ecdeea5ec4 100644
--- a/R/pkg/R/generics.R
+++ b/R/pkg/R/generics.R
@@ -1152,6 +1152,10 @@ setGeneric("var_samp", function(x) { standardGeneric("var_samp") })
#' @export
setGeneric("weekofyear", function(x) { standardGeneric("weekofyear") })
+#' @rdname window
+#' @export
+setGeneric("window", function(x, ...) { standardGeneric("window") })
+
#' @rdname year
#' @export
setGeneric("year", function(x) { standardGeneric("year") })
diff --git a/R/pkg/R/mllib.R b/R/pkg/R/mllib.R
index 33654d5216..31bca16580 100644
--- a/R/pkg/R/mllib.R
+++ b/R/pkg/R/mllib.R
@@ -17,10 +17,10 @@
# mllib.R: Provides methods for MLlib integration
-#' @title S4 class that represents a PipelineModel
-#' @param model A Java object reference to the backing Scala PipelineModel
+#' @title S4 class that represents a generalized linear model
+#' @param jobj a Java object reference to the backing Scala GeneralizedLinearRegressionWrapper
#' @export
-setClass("PipelineModel", representation(model = "jobj"))
+setClass("GeneralizedLinearRegressionModel", representation(jobj = "jobj"))
#' @title S4 class that represents a NaiveBayesModel
#' @param jobj a Java object reference to the backing Scala NaiveBayesWrapper
@@ -32,23 +32,25 @@ setClass("NaiveBayesModel", representation(jobj = "jobj"))
#' @export
setClass("AFTSurvivalRegressionModel", representation(jobj = "jobj"))
+#' @title S4 class that represents a KMeansModel
+#' @param jobj a Java object reference to the backing Scala KMeansModel
+#' @export
+setClass("KMeansModel", representation(jobj = "jobj"))
+
#' Fits a generalized linear model
#'
-#' Fits a generalized linear model, similarly to R's glm(). Also see the glmnet package.
+#' Fits a generalized linear model, similarly to R's glm().
#'
#' @param formula A symbolic description of the model to be fitted. Currently only a few formula
#' operators are supported, including '~', '.', ':', '+', and '-'.
-#' @param data DataFrame for training
-#' @param family Error distribution. "gaussian" -> linear regression, "binomial" -> logistic reg.
-#' @param lambda Regularization parameter
-#' @param alpha Elastic-net mixing parameter (see glmnet's documentation for details)
-#' @param standardize Whether to standardize features before training
-#' @param solver The solver algorithm used for optimization, this can be "l-bfgs", "normal" and
-#' "auto". "l-bfgs" denotes Limited-memory BFGS which is a limited-memory
-#' quasi-Newton optimization method. "normal" denotes using Normal Equation as an
-#' analytical solution to the linear regression problem. The default value is "auto"
-#' which means that the solver algorithm is selected automatically.
-#' @return a fitted MLlib model
+#' @param data DataFrame for training.
+#' @param family A description of the error distribution and link function to be used in the model.
+#' This can be a character string naming a family function, a family function or
+#' the result of a call to a family function. Refer to R's family documentation at
+#' \url{https://stat.ethz.ch/R-manual/R-devel/library/stats/html/family.html}.
+#' @param epsilon Positive convergence tolerance of iterations.
+#' @param maxit Integer giving the maximal number of IRLS iterations.
+#' @return a fitted generalized linear model
#' @rdname glm
#' @export
#' @examples
@@ -59,25 +61,59 @@ setClass("AFTSurvivalRegressionModel", representation(jobj = "jobj"))
#' df <- createDataFrame(sqlContext, iris)
#' model <- glm(Sepal_Length ~ Sepal_Width, df, family="gaussian")
#' summary(model)
-#'}
+#' }
setMethod("glm", signature(formula = "formula", family = "ANY", data = "DataFrame"),
- function(formula, family = c("gaussian", "binomial"), data, lambda = 0, alpha = 0,
- standardize = TRUE, solver = "auto") {
- family <- match.arg(family)
+ function(formula, family = gaussian, data, epsilon = 1e-06, maxit = 25) {
+ if (is.character(family)) {
+ family <- get(family, mode = "function", envir = parent.frame())
+ }
+ if (is.function(family)) {
+ family <- family()
+ }
+ if (is.null(family$family)) {
+ print(family)
+ stop("'family' not recognized")
+ }
+
formula <- paste(deparse(formula), collapse = "")
- model <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers",
- "fitRModelFormula", formula, data@sdf, family, lambda,
- alpha, standardize, solver)
- return(new("PipelineModel", model = model))
+
+ jobj <- callJStatic("org.apache.spark.ml.r.GeneralizedLinearRegressionWrapper",
+ "fit", formula, data@sdf, family$family, family$link,
+ epsilon, as.integer(maxit))
+ return(new("GeneralizedLinearRegressionModel", jobj = jobj))
})
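A minimal sketch of the new glm() signature in use (not part of the patch; it reuses the sqlContext/iris setup from the roxygen example above):

  df <- suppressWarnings(createDataFrame(sqlContext, iris))
  # family may be a character string, a family function, or a call to a family function
  m1 <- glm(Sepal_Length ~ Sepal_Width + Species, data = df, family = "gaussian")
  m2 <- glm(Sepal_Width ~ Sepal_Length + Species, data = df,
            family = poisson(link = identity), epsilon = 1e-06, maxit = 25)
  summary(m1)               # list with a 'coefficients' matrix
  showDF(predict(m1, df))   # predicted values in a column named "prediction"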
-#' Make predictions from a model
+#' Get the summary of a generalized linear model
#'
-#' Makes predictions from a model produced by glm(), similarly to R's predict().
+#' Returns the summary of a model produced by glm(), similarly to R's summary().
#'
-#' @param object A fitted MLlib model
+#' @param object A fitted generalized linear model
+#' @return a list with component 'coefficients', the model's coefficient estimates (including the intercept)
+#' @rdname summary
+#' @export
+#' @examples
+#' \dontrun{
+#' model <- glm(y ~ x, trainingData)
+#' summary(model)
+#' }
+setMethod("summary", signature(object = "GeneralizedLinearRegressionModel"),
+ function(object, ...) {
+ jobj <- object@jobj
+ features <- callJMethod(jobj, "rFeatures")
+ coefficients <- callJMethod(jobj, "rCoefficients")
+ coefficients <- as.matrix(unlist(coefficients))
+ colnames(coefficients) <- c("Estimate")
+ rownames(coefficients) <- unlist(features)
+ return(list(coefficients = coefficients))
+ })
+
+#' Make predictions from a generalized linear model
+#'
+#' Makes predictions from a generalized linear model produced by glm(), similarly to R's predict().
+#'
+#' @param object A fitted generalized linear model
#' @param newData DataFrame for testing
-#' @return DataFrame containing predicted values
+#' @return DataFrame containing predicted labels in a column named "prediction"
#' @rdname predict
#' @export
#' @examples
@@ -85,10 +121,10 @@ setMethod("glm", signature(formula = "formula", family = "ANY", data = "DataFram
#' model <- glm(y ~ x, trainingData)
#' predicted <- predict(model, testData)
#' showDF(predicted)
-#'}
-setMethod("predict", signature(object = "PipelineModel"),
+#' }
+setMethod("predict", signature(object = "GeneralizedLinearRegressionModel"),
function(object, newData) {
- return(dataFrame(callJMethod(object@model, "transform", newData@sdf)))
+ return(dataFrame(callJMethod(object@jobj, "transform", newData@sdf)))
})
#' Make predictions from a naive Bayes model
@@ -111,65 +147,6 @@ setMethod("predict", signature(object = "NaiveBayesModel"),
return(dataFrame(callJMethod(object@jobj, "transform", newData@sdf)))
})
-#' Get the summary of a model
-#'
-#' Returns the summary of a model produced by glm(), similarly to R's summary().
-#'
-#' @param object A fitted MLlib model
-#' @return a list with 'devianceResiduals' and 'coefficients' components for gaussian family
-#' or a list with 'coefficients' component for binomial family. \cr
-#' For gaussian family: the 'devianceResiduals' gives the min/max deviance residuals
-#' of the estimation, the 'coefficients' gives the estimated coefficients and their
-#' estimated standard errors, t values and p-values. (It only available when model
-#' fitted by normal solver.) \cr
-#' For binomial family: the 'coefficients' gives the estimated coefficients.
-#' See summary.glm for more information. \cr
-#' @rdname summary
-#' @export
-#' @examples
-#' \dontrun{
-#' model <- glm(y ~ x, trainingData)
-#' summary(model)
-#'}
-setMethod("summary", signature(object = "PipelineModel"),
- function(object, ...) {
- modelName <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers",
- "getModelName", object@model)
- features <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers",
- "getModelFeatures", object@model)
- coefficients <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers",
- "getModelCoefficients", object@model)
- if (modelName == "LinearRegressionModel") {
- devianceResiduals <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers",
- "getModelDevianceResiduals", object@model)
- devianceResiduals <- matrix(devianceResiduals, nrow = 1)
- colnames(devianceResiduals) <- c("Min", "Max")
- rownames(devianceResiduals) <- rep("", times = 1)
- coefficients <- matrix(coefficients, ncol = 4)
- colnames(coefficients) <- c("Estimate", "Std. Error", "t value", "Pr(>|t|)")
- rownames(coefficients) <- unlist(features)
- return(list(devianceResiduals = devianceResiduals, coefficients = coefficients))
- } else if (modelName == "LogisticRegressionModel") {
- coefficients <- as.matrix(unlist(coefficients))
- colnames(coefficients) <- c("Estimate")
- rownames(coefficients) <- unlist(features)
- return(list(coefficients = coefficients))
- } else if (modelName == "KMeansModel") {
- modelSize <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers",
- "getKMeansModelSize", object@model)
- cluster <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers",
- "getKMeansCluster", object@model, "classes")
- k <- unlist(modelSize)[1]
- size <- unlist(modelSize)[-1]
- coefficients <- t(matrix(coefficients, ncol = k))
- colnames(coefficients) <- unlist(features)
- rownames(coefficients) <- 1:k
- return(list(coefficients = coefficients, size = size, cluster = dataFrame(cluster)))
- } else {
- stop(paste("Unsupported model", modelName, sep = " "))
- }
- })
-
#' Get the summary of a naive Bayes model
#'
#' Returns the summary of a naive Bayes model produced by naiveBayes(), similarly to R's summary().
@@ -213,21 +190,21 @@ setMethod("summary", signature(object = "NaiveBayesModel"),
#' @examples
#' \dontrun{
#' model <- kmeans(x, centers = 2, algorithm="random")
-#'}
+#' }
setMethod("kmeans", signature(x = "DataFrame"),
function(x, centers, iter.max = 10, algorithm = c("random", "k-means||")) {
columnNames <- as.array(colnames(x))
algorithm <- match.arg(algorithm)
- model <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers", "fitKMeans", x@sdf,
- algorithm, iter.max, centers, columnNames)
- return(new("PipelineModel", model = model))
+ jobj <- callJStatic("org.apache.spark.ml.r.KMeansWrapper", "fit", x@sdf,
+ centers, iter.max, algorithm, columnNames)
+ return(new("KMeansModel", jobj = jobj))
})
-#' Get fitted result from a model
+#' Get fitted result from a k-means model
#'
-#' Get fitted result from a model, similarly to R's fitted().
+#' Get fitted result from a k-means model, similarly to R's fitted().
#'
-#' @param object A fitted MLlib model
+#' @param object A fitted k-means model
#' @return DataFrame containing fitted values
#' @rdname fitted
#' @export
@@ -237,19 +214,58 @@ setMethod("kmeans", signature(x = "DataFrame"),
#' fitted.model <- fitted(model)
#' showDF(fitted.model)
#'}
-setMethod("fitted", signature(object = "PipelineModel"),
+setMethod("fitted", signature(object = "KMeansModel"),
function(object, method = c("centers", "classes"), ...) {
- modelName <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers",
- "getModelName", object@model)
+ method <- match.arg(method)
+ return(dataFrame(callJMethod(object@jobj, "fitted", method)))
+ })
- if (modelName == "KMeansModel") {
- method <- match.arg(method)
- fittedResult <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers",
- "getKMeansCluster", object@model, method)
- return(dataFrame(fittedResult))
- } else {
- stop(paste("Unsupported model", modelName, sep = " "))
- }
+#' Get the summary of a k-means model
+#'
+#' Returns the summary of a k-means model produced by kmeans(),
+#' similarly to R's summary().
+#'
+#' @param object a fitted k-means model
+#' @return the model's coefficients, size and cluster
+#' @rdname summary
+#' @export
+#' @examples
+#' \dontrun{
+#' model <- kmeans(trainingData, 2)
+#' summary(model)
+#' }
+setMethod("summary", signature(object = "KMeansModel"),
+ function(object, ...) {
+ jobj <- object@jobj
+ features <- callJMethod(jobj, "features")
+ coefficients <- callJMethod(jobj, "coefficients")
+ cluster <- callJMethod(jobj, "cluster")
+ k <- callJMethod(jobj, "k")
+ size <- callJMethod(jobj, "size")
+ coefficients <- t(matrix(coefficients, ncol = k))
+ colnames(coefficients) <- unlist(features)
+ rownames(coefficients) <- 1:k
+ return(list(coefficients = coefficients, size = size, cluster = dataFrame(cluster)))
+ })
+
+#' Make predictions from a k-means model
+#'
+#' Make predictions from a model produced by kmeans().
+#'
+#' @param object A fitted k-means model
+#' @param newData DataFrame for testing
+#' @return DataFrame containing predicted labels in a column named "prediction"
+#' @rdname predict
+#' @export
+#' @examples
+#' \dontrun{
+#' model <- kmeans(trainingData, 2)
+#' predicted <- predict(model, testData)
+#' showDF(predicted)
+#' }
+setMethod("predict", signature(object = "KMeansModel"),
+ function(object, newData) {
+ return(dataFrame(callJMethod(object@jobj, "transform", newData@sdf)))
})
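Putting the new KMeansModel pieces together, a sketch of the end-to-end workflow (not part of the patch; it assumes a numeric DataFrame named training, an illustrative name):

  model <- kmeans(training, centers = 2, iter.max = 10, algorithm = "random")
  summary(model)                    # coefficients (cluster centers), size and cluster assignments
  showDF(fitted(model, "classes"))  # cluster index for each training row
  showDF(predict(model, training))  # appends a "prediction" column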
#' Fit a Bernoulli naive Bayes model
diff --git a/R/pkg/inst/tests/testthat/test_context.R b/R/pkg/inst/tests/testthat/test_context.R
index ad3f9722a4..6e06c974c2 100644
--- a/R/pkg/inst/tests/testthat/test_context.R
+++ b/R/pkg/inst/tests/testthat/test_context.R
@@ -26,7 +26,7 @@ test_that("Check masked functions", {
maskedBySparkR <- masked[funcSparkROrEmpty]
namesOfMasked <- c("describe", "cov", "filter", "lag", "na.omit", "predict", "sd", "var",
"colnames", "colnames<-", "intersect", "rank", "rbind", "sample", "subset",
- "summary", "transform", "drop")
+ "summary", "transform", "drop", "window")
expect_equal(length(maskedBySparkR), length(namesOfMasked))
expect_equal(sort(maskedBySparkR), sort(namesOfMasked))
# above are those reported as masked when `library(SparkR)`
diff --git a/R/pkg/inst/tests/testthat/test_mllib.R b/R/pkg/inst/tests/testthat/test_mllib.R
index fdb591756e..a9dbd2bdc4 100644
--- a/R/pkg/inst/tests/testthat/test_mllib.R
+++ b/R/pkg/inst/tests/testthat/test_mllib.R
@@ -25,20 +25,21 @@ sc <- sparkR.init()
sqlContext <- sparkRSQL.init(sc)
-test_that("glm and predict", {
+test_that("formula of glm", {
training <- suppressWarnings(createDataFrame(sqlContext, iris))
- test <- select(training, "Sepal_Length")
- model <- glm(Sepal_Width ~ Sepal_Length, training, family = "gaussian")
- prediction <- predict(model, test)
- expect_equal(typeof(take(select(prediction, "prediction"), 1)$prediction), "double")
+ # dot minus and intercept vs native glm
+ model <- glm(Sepal_Width ~ . - Species + 0, data = training)
+ vals <- collect(select(predict(model, training), "prediction"))
+ rVals <- predict(glm(Sepal.Width ~ . - Species + 0, data = iris), iris)
+ expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals)
- # Test stats::predict is working
- x <- rnorm(15)
- y <- x + rnorm(15)
- expect_equal(length(predict(lm(y ~ x))), 15)
-})
+ # feature interaction vs native glm
+ model <- glm(Sepal_Width ~ Species:Sepal_Length, data = training)
+ vals <- collect(select(predict(model, training), "prediction"))
+ rVals <- predict(glm(Sepal.Width ~ Species:Sepal.Length, data = iris), iris)
+ expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals)
-test_that("glm should work with long formula", {
+ # glm should work with long formula
training <- suppressWarnings(createDataFrame(sqlContext, iris))
training$LongLongLongLongLongName <- training$Sepal_Width
training$VeryLongLongLongLonLongName <- training$Sepal_Length
@@ -50,68 +51,30 @@ test_that("glm should work with long formula", {
expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals)
})
-test_that("predictions match with native glm", {
+test_that("glm and predict", {
training <- suppressWarnings(createDataFrame(sqlContext, iris))
+ # gaussian family
model <- glm(Sepal_Width ~ Sepal_Length + Species, data = training)
- vals <- collect(select(predict(model, training), "prediction"))
+ prediction <- predict(model, training)
+ expect_equal(typeof(take(select(prediction, "prediction"), 1)$prediction), "double")
+ vals <- collect(select(prediction, "prediction"))
rVals <- predict(glm(Sepal.Width ~ Sepal.Length + Species, data = iris), iris)
expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals)
-})
-
-test_that("dot minus and intercept vs native glm", {
- training <- suppressWarnings(createDataFrame(sqlContext, iris))
- model <- glm(Sepal_Width ~ . - Species + 0, data = training)
- vals <- collect(select(predict(model, training), "prediction"))
- rVals <- predict(glm(Sepal.Width ~ . - Species + 0, data = iris), iris)
- expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals)
-})
-test_that("feature interaction vs native glm", {
- training <- suppressWarnings(createDataFrame(sqlContext, iris))
- model <- glm(Sepal_Width ~ Species:Sepal_Length, data = training)
- vals <- collect(select(predict(model, training), "prediction"))
- rVals <- predict(glm(Sepal.Width ~ Species:Sepal.Length, data = iris), iris)
+ # poisson family
+ model <- glm(Sepal_Width ~ Sepal_Length + Species, data = training,
+ family = poisson(link = identity))
+ prediction <- predict(model, training)
+ expect_equal(typeof(take(select(prediction, "prediction"), 1)$prediction), "double")
+ vals <- collect(select(prediction, "prediction"))
+ rVals <- suppressWarnings(predict(glm(Sepal.Width ~ Sepal.Length + Species,
+ data = iris, family = poisson(link = identity)), iris))
expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals)
-})
-test_that("summary coefficients match with native glm", {
- training <- suppressWarnings(createDataFrame(sqlContext, iris))
- stats <- summary(glm(Sepal_Width ~ Sepal_Length + Species, data = training, solver = "normal"))
- coefs <- unlist(stats$coefficients)
- devianceResiduals <- unlist(stats$devianceResiduals)
-
- rStats <- summary(glm(Sepal.Width ~ Sepal.Length + Species, data = iris))
- rCoefs <- unlist(rStats$coefficients)
- rDevianceResiduals <- c(-0.95096, 0.72918)
-
- expect_true(all(abs(rCoefs - coefs) < 1e-5))
- expect_true(all(abs(rDevianceResiduals - devianceResiduals) < 1e-5))
- expect_true(all(
- rownames(stats$coefficients) ==
- c("(Intercept)", "Sepal_Length", "Species_versicolor", "Species_virginica")))
-})
-
-test_that("summary coefficients match with native glm of family 'binomial'", {
- df <- suppressWarnings(createDataFrame(sqlContext, iris))
- training <- filter(df, df$Species != "setosa")
- stats <- summary(glm(Species ~ Sepal_Length + Sepal_Width, data = training,
- family = "binomial"))
- coefs <- as.vector(stats$coefficients[, 1])
-
- rTraining <- iris[iris$Species %in% c("versicolor", "virginica"), ]
- rCoefs <- as.vector(coef(glm(Species ~ Sepal.Length + Sepal.Width, data = rTraining,
- family = binomial(link = "logit"))))
-
- expect_true(all(abs(rCoefs - coefs) < 1e-4))
- expect_true(all(
- rownames(stats$coefficients) ==
- c("(Intercept)", "Sepal_Length", "Sepal_Width")))
-})
-
-test_that("summary works on base GLM models", {
- baseModel <- stats::glm(Sepal.Width ~ Sepal.Length + Species, data = iris)
- baseSummary <- summary(baseModel)
- expect_true(abs(baseSummary$deviance - 12.19313) < 1e-4)
+ # Test stats::predict is working
+ x <- rnorm(15)
+ y <- x + rnorm(15)
+ expect_equal(length(predict(lm(y ~ x))), 15)
})
test_that("kmeans", {
diff --git a/R/pkg/inst/tests/testthat/test_rdd.R b/R/pkg/inst/tests/testthat/test_rdd.R
index 3b0c16be5a..b6c8e1dc6c 100644
--- a/R/pkg/inst/tests/testthat/test_rdd.R
+++ b/R/pkg/inst/tests/testthat/test_rdd.R
@@ -791,3 +791,11 @@ test_that("sampleByKey() on pairwise RDDs", {
expect_equal(lookup(sample, 3)[which.min(lookup(sample, 3))] >= 0, TRUE)
expect_equal(lookup(sample, 3)[which.max(lookup(sample, 3))] <= 2000, TRUE)
})
+
+test_that("Test correct concurrency of RRDD.compute()", {
+ rdd <- parallelize(sc, 1:1000, 100)
+ jrdd <- getJRDD(lapply(rdd, function(x) { x }), "row")
+ zrdd <- callJMethod(jrdd, "zip", jrdd)
+ count <- callJMethod(zrdd, "count")
+ expect_equal(count, 1000)
+})
diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R
index eef365b42e..d747d4f83f 100644
--- a/R/pkg/inst/tests/testthat/test_sparkSQL.R
+++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R
@@ -1204,6 +1204,42 @@ test_that("greatest() and least() on a DataFrame", {
expect_equal(collect(select(df, least(df$a, df$b)))[, 1], c(1, 3))
})
+test_that("time windowing (window()) with all inputs", {
+ df <- createDataFrame(sqlContext, data.frame(t = c("2016-03-11 09:00:07"), v = c(1)))
+ df$window <- window(df$t, "5 seconds", "5 seconds", "0 seconds")
+ local <- collect(df)$v
+ # Not checking time windows because of possible time zone issues. Just checking that the function
+ # works
+ expect_equal(local, c(1))
+})
+
+test_that("time windowing (window()) with slide duration", {
+ df <- createDataFrame(sqlContext, data.frame(t = c("2016-03-11 09:00:07"), v = c(1)))
+ df$window <- window(df$t, "5 seconds", "2 seconds")
+ local <- collect(df)$v
+ # Not checking time windows because of possible time zone issues. Just checking that the function
+ # works
+ expect_equal(local, c(1, 1))
+})
+
+test_that("time windowing (window()) with start time", {
+ df <- createDataFrame(sqlContext, data.frame(t = c("2016-03-11 09:00:07"), v = c(1)))
+ df$window <- window(df$t, "5 seconds", startTime = "2 seconds")
+ local <- collect(df)$v
+ # Not checking time windows because of possible time zone issues. Just checking that the function
+ # works
+ expect_equal(local, c(1))
+})
+
+test_that("time windowing (window()) with just window duration", {
+ df <- createDataFrame(sqlContext, data.frame(t = c("2016-03-11 09:00:07"), v = c(1)))
+ df$window <- window(df$t, "5 seconds")
+ local <- collect(df)$v
+ # Not checking time windows because of possible time zone issues. Just checking that the function
+ # works
+ expect_equal(local, c(1))
+})
+
test_that("when(), otherwise() and ifelse() on a DataFrame", {
l <- list(list(a = 1, b = 2), list(a = 3, b = 4))
df <- createDataFrame(sqlContext, l)
@@ -1817,7 +1853,7 @@ test_that("approxQuantile() on a DataFrame", {
test_that("SQL error message is returned from JVM", {
retError <- tryCatch(sql(sqlContext, "select * from blah"), error = function(e) e)
- expect_equal(grepl("Table not found", retError), TRUE)
+ expect_equal(grepl("Table or View not found", retError), TRUE)
expect_equal(grepl("blah", retError), TRUE)
})
diff --git a/assembly/pom.xml b/assembly/pom.xml
index 477d4931c3..22cbac06ca 100644
--- a/assembly/pom.xml
+++ b/assembly/pom.xml
@@ -33,9 +33,8 @@
<properties>
<sbt.project.name>assembly</sbt.project.name>
- <spark.jar.dir>scala-${scala.binary.version}</spark.jar.dir>
- <spark.jar.basename>spark-assembly-${project.version}-hadoop${hadoop.version}.jar</spark.jar.basename>
- <spark.jar>${project.build.directory}/${spark.jar.dir}/${spark.jar.basename}</spark.jar>
+ <build.testJarPhase>none</build.testJarPhase>
+ <build.copyDependenciesPhase>package</build.copyDependenciesPhase>
</properties>
<dependencies>
@@ -69,6 +68,17 @@
<artifactId>spark-repl_${scala.binary.version}</artifactId>
<version>${project.version}</version>
</dependency>
+
+ <!--
+ Because we don't shade dependencies anymore, we need to restore Guava to compile scope so
+ that the libraries Spark depends on have it available. We'll package the version that Spark
+ uses (14.0.1), which is not the same as the one the Hadoop dependencies pull in, but it works.
+ -->
+ <dependency>
+ <groupId>com.google.guava</groupId>
+ <artifactId>guava</artifactId>
+ <scope>${hadoop.deps.scope}</scope>
+ </dependency>
</dependencies>
<build>
@@ -87,75 +97,26 @@
<skip>true</skip>
</configuration>
</plugin>
- <!-- zip pyspark archives to run python application on yarn mode -->
- <plugin>
- <groupId>org.apache.maven.plugins</groupId>
- <artifactId>maven-antrun-plugin</artifactId>
- <executions>
- <execution>
- <phase>package</phase>
- <goals>
- <goal>run</goal>
- </goals>
- </execution>
- </executions>
- <configuration>
- <target>
- <delete dir="${basedir}/../python/lib/pyspark.zip"/>
- <zip destfile="${basedir}/../python/lib/pyspark.zip">
- <fileset dir="${basedir}/../python/" includes="pyspark/**/*"/>
- </zip>
- </target>
- </configuration>
- </plugin>
- <!-- Use the shade plugin to create a big JAR with all the dependencies -->
+ <!-- zip pyspark archives to run python application on yarn mode -->
<plugin>
<groupId>org.apache.maven.plugins</groupId>
- <artifactId>maven-shade-plugin</artifactId>
- <configuration>
- <shadedArtifactAttached>false</shadedArtifactAttached>
- <outputFile>${spark.jar}</outputFile>
- <artifactSet>
- <includes>
- <include>*:*</include>
- </includes>
- </artifactSet>
- <filters>
- <filter>
- <artifact>*:*</artifact>
- <excludes>
- <exclude>org/datanucleus/**</exclude>
- <exclude>META-INF/*.SF</exclude>
- <exclude>META-INF/*.DSA</exclude>
- <exclude>META-INF/*.RSA</exclude>
- </excludes>
- </filter>
- </filters>
- </configuration>
- <executions>
- <execution>
- <phase>package</phase>
- <goals>
- <goal>shade</goal>
- </goals>
- <configuration>
- <transformers>
- <transformer implementation="org.apache.maven.plugins.shade.resource.ServicesResourceTransformer" />
- <transformer implementation="org.apache.maven.plugins.shade.resource.AppendingTransformer">
- <resource>META-INF/services/org.apache.hadoop.fs.FileSystem</resource>
- </transformer>
- <transformer implementation="org.apache.maven.plugins.shade.resource.AppendingTransformer">
- <resource>reference.conf</resource>
- </transformer>
- <transformer implementation="org.apache.maven.plugins.shade.resource.DontIncludeResourceTransformer">
- <resource>log4j.properties</resource>
- </transformer>
- <transformer implementation="org.apache.maven.plugins.shade.resource.ApacheLicenseResourceTransformer"/>
- <transformer implementation="org.apache.maven.plugins.shade.resource.ApacheNoticeResourceTransformer"/>
- </transformers>
- </configuration>
- </execution>
- </executions>
+ <artifactId>maven-antrun-plugin</artifactId>
+ <executions>
+ <execution>
+ <phase>package</phase>
+ <goals>
+ <goal>run</goal>
+ </goals>
+ </execution>
+ </executions>
+ <configuration>
+ <target>
+ <delete dir="${basedir}/../python/lib/pyspark.zip"/>
+ <zip destfile="${basedir}/../python/lib/pyspark.zip">
+ <fileset dir="${basedir}/../python/" includes="pyspark/**/*"/>
+ </zip>
+ </target>
+ </configuration>
</plugin>
</plugins>
</build>
diff --git a/bin/spark-class b/bin/spark-class
index e710e388be..b2a36b9846 100755
--- a/bin/spark-class
+++ b/bin/spark-class
@@ -36,21 +36,20 @@ else
fi
# Find Spark jars.
-# TODO: change the directory name when Spark jars move from "lib".
if [ -f "${SPARK_HOME}/RELEASE" ]; then
- SPARK_JARS_DIR="${SPARK_HOME}/lib"
+ SPARK_JARS_DIR="${SPARK_HOME}/jars"
else
- SPARK_JARS_DIR="${SPARK_HOME}/assembly/target/scala-$SPARK_SCALA_VERSION"
+ SPARK_JARS_DIR="${SPARK_HOME}/assembly/target/scala-$SPARK_SCALA_VERSION/jars"
fi
-if [ ! -d "$SPARK_JARS_DIR" ]; then
+if [ ! -d "$SPARK_JARS_DIR" ] && [ -z "$SPARK_TESTING$SPARK_SQL_TESTING" ]; then
echo "Failed to find Spark jars directory ($SPARK_JARS_DIR)." 1>&2
- echo "You need to build Spark before running this program." 1>&2
+ echo "You need to build Spark with the target \"package\" before running this program." 1>&2
exit 1
+else
+ LAUNCH_CLASSPATH="$SPARK_JARS_DIR/*"
fi
-LAUNCH_CLASSPATH="$SPARK_JARS_DIR/*"
-
# Add the launcher build dir to the classpath if requested.
if [ -n "$SPARK_PREPEND_CLASSES" ]; then
LAUNCH_CLASSPATH="${SPARK_HOME}/launcher/target/scala-$SPARK_SCALA_VERSION/classes:$LAUNCH_CLASSPATH"
diff --git a/bin/spark-class2.cmd b/bin/spark-class2.cmd
index 565b87c102..579efff909 100644
--- a/bin/spark-class2.cmd
+++ b/bin/spark-class2.cmd
@@ -29,11 +29,10 @@ if "x%1"=="x" (
)
rem Find Spark jars.
-rem TODO: change the directory name when Spark jars move from "lib".
if exist "%SPARK_HOME%\RELEASE" (
- set SPARK_JARS_DIR="%SPARK_HOME%\lib"
+ set SPARK_JARS_DIR="%SPARK_HOME%\jars"
) else (
- set SPARK_JARS_DIR="%SPARK_HOME%\assembly\target\scala-%SPARK_SCALA_VERSION%"
+ set SPARK_JARS_DIR="%SPARK_HOME%\assembly\target\scala-%SPARK_SCALA_VERSION%\jars"
)
if not exist "%SPARK_JARS_DIR%"\ (
diff --git a/build/mvn b/build/mvn
index 58058c04b8..eb42552fc4 100755
--- a/build/mvn
+++ b/build/mvn
@@ -70,9 +70,10 @@ install_app() {
# Install maven under the build/ folder
install_mvn() {
local MVN_VERSION="3.3.9"
+ local APACHE_MIRROR=${APACHE_MIRROR:-'https://www.apache.org/dyn/closer.lua?action=download&filename='}
install_app \
- "http://archive.apache.org/dist/maven/maven-3/${MVN_VERSION}/binaries" \
+ "${APACHE_MIRROR}/maven/maven-3/${MVN_VERSION}/binaries" \
"apache-maven-${MVN_VERSION}-bin.tar.gz" \
"apache-maven-${MVN_VERSION}/bin/mvn"
@@ -83,8 +84,10 @@ install_mvn() {
install_zinc() {
local zinc_path="zinc-0.3.9/bin/zinc"
[ ! -f "${_DIR}/${zinc_path}" ] && ZINC_INSTALL_FLAG=1
+ local TYPESAFE_MIRROR=${TYPESAFE_MIRROR:-https://downloads.typesafe.com}
+
install_app \
- "http://downloads.typesafe.com/zinc/0.3.9" \
+ "${TYPESAFE_MIRROR}/zinc/0.3.9" \
"zinc-0.3.9.tgz" \
"${zinc_path}"
ZINC_BIN="${_DIR}/${zinc_path}"
@@ -98,9 +101,10 @@ install_scala() {
local scala_version=`grep "scala.version" "${_DIR}/../pom.xml" | \
head -1 | cut -f2 -d'>' | cut -f1 -d'<'`
local scala_bin="${_DIR}/scala-${scala_version}/bin/scala"
+ local TYPESAFE_MIRROR=${TYPESAFE_MIRROR:-https://downloads.typesafe.com}
install_app \
- "http://downloads.typesafe.com/scala/${scala_version}" \
+ "${TYPESAFE_MIRROR}/scala/${scala_version}" \
"scala-${scala_version}.tgz" \
"scala-${scala_version}/bin/scala"
diff --git a/common/network-common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java b/common/network-common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java
index f179bad1f4..a27aaf2b27 100644
--- a/common/network-common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java
+++ b/common/network-common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java
@@ -94,7 +94,7 @@ public class TransportClientFactory implements Closeable {
this.context = Preconditions.checkNotNull(context);
this.conf = context.getConf();
this.clientBootstraps = Lists.newArrayList(Preconditions.checkNotNull(clientBootstraps));
- this.connectionPool = new ConcurrentHashMap<SocketAddress, ClientPool>();
+ this.connectionPool = new ConcurrentHashMap<>();
this.numConnectionsPerPeer = conf.numConnectionsPerPeer();
this.rand = new Random();
@@ -123,13 +123,15 @@ public class TransportClientFactory implements Closeable {
public TransportClient createClient(String remoteHost, int remotePort) throws IOException {
// Get connection from the connection pool first.
// If it is not found or not active, create a new one.
- final InetSocketAddress address = new InetSocketAddress(remoteHost, remotePort);
+ // Use unresolved address here to avoid DNS resolution each time we create a client.
+ final InetSocketAddress unresolvedAddress =
+ InetSocketAddress.createUnresolved(remoteHost, remotePort);
// Create the ClientPool if we don't have it yet.
- ClientPool clientPool = connectionPool.get(address);
+ ClientPool clientPool = connectionPool.get(unresolvedAddress);
if (clientPool == null) {
- connectionPool.putIfAbsent(address, new ClientPool(numConnectionsPerPeer));
- clientPool = connectionPool.get(address);
+ connectionPool.putIfAbsent(unresolvedAddress, new ClientPool(numConnectionsPerPeer));
+ clientPool = connectionPool.get(unresolvedAddress);
}
int clientIndex = rand.nextInt(numConnectionsPerPeer);
@@ -146,25 +148,35 @@ public class TransportClientFactory implements Closeable {
}
if (cachedClient.isActive()) {
- logger.trace("Returning cached connection to {}: {}", address, cachedClient);
+ logger.trace("Returning cached connection to {}: {}",
+ cachedClient.getSocketAddress(), cachedClient);
return cachedClient;
}
}
// If we reach here, we don't have an existing connection open. Let's create a new one.
// Multiple threads might race here to create new connections. Keep only one of them active.
+ final long preResolveHost = System.nanoTime();
+ final InetSocketAddress resolvedAddress = new InetSocketAddress(remoteHost, remotePort);
+ final long hostResolveTimeMs = (System.nanoTime() - preResolveHost) / 1000000;
+ if (hostResolveTimeMs > 2000) {
+ logger.warn("DNS resolution for {} took {} ms", resolvedAddress, hostResolveTimeMs);
+ } else {
+ logger.trace("DNS resolution for {} took {} ms", resolvedAddress, hostResolveTimeMs);
+ }
+
synchronized (clientPool.locks[clientIndex]) {
cachedClient = clientPool.clients[clientIndex];
if (cachedClient != null) {
if (cachedClient.isActive()) {
- logger.trace("Returning cached connection to {}: {}", address, cachedClient);
+ logger.trace("Returning cached connection to {}: {}", resolvedAddress, cachedClient);
return cachedClient;
} else {
- logger.info("Found inactive connection to {}, creating a new one.", address);
+ logger.info("Found inactive connection to {}, creating a new one.", resolvedAddress);
}
}
- clientPool.clients[clientIndex] = createClient(address);
+ clientPool.clients[clientIndex] = createClient(resolvedAddress);
return clientPool.clients[clientIndex];
}
}
@@ -235,7 +247,7 @@ public class TransportClientFactory implements Closeable {
}
long postBootstrap = System.nanoTime();
- logger.debug("Successfully created connection to {} after {} ms ({} ms spent in bootstraps)",
+ logger.info("Successfully created connection to {} after {} ms ({} ms spent in bootstraps)",
address, (postBootstrap - preConnect) / 1000000, (postBootstrap - preBootstrap) / 1000000);
return client;
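Note on the hunk above: the connection pool is keyed on an unresolved InetSocketAddress, so the cached-client fast path never touches DNS; the host name is only resolved (and timed) on the slow path that actually opens a socket. A minimal sketch of that pattern, with a hypothetical Connection type and getOrCreate() helper standing in for Spark's TransportClient and createClient():

import java.net.InetSocketAddress;
import java.util.concurrent.ConcurrentHashMap;

public class UnresolvedPoolSketch {
  // Hypothetical stand-in for a pooled client/connection.
  static final class Connection {
    final InetSocketAddress resolved;
    Connection(InetSocketAddress resolved) { this.resolved = resolved; }
  }

  private final ConcurrentHashMap<InetSocketAddress, Connection> pool = new ConcurrentHashMap<>();

  Connection getOrCreate(String host, int port) {
    // Fast path: the unresolved address is only a map key, so no DNS lookup happens here.
    InetSocketAddress key = InetSocketAddress.createUnresolved(host, port);
    Connection cached = pool.get(key);
    if (cached != null) {
      return cached;
    }
    // Slow path: resolve the host only when a new connection is needed, and note slow lookups.
    long start = System.nanoTime();
    InetSocketAddress resolved = new InetSocketAddress(host, port);
    long resolveMs = (System.nanoTime() - start) / 1_000_000;
    if (resolveMs > 2000) {
      System.err.println("DNS resolution for " + resolved + " took " + resolveMs + " ms");
    }
    return pool.computeIfAbsent(key, k -> new Connection(resolved));
  }

  public static void main(String[] args) {
    UnresolvedPoolSketch sketch = new UnresolvedPoolSketch();
    sketch.getOrCreate("localhost", 7077);
    sketch.getOrCreate("localhost", 7077); // second call hits the pool, no resolution
  }
}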
diff --git a/common/network-common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java b/common/network-common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java
index f0e2004d2d..8a69223c88 100644
--- a/common/network-common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java
+++ b/common/network-common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java
@@ -64,9 +64,9 @@ public class TransportResponseHandler extends MessageHandler<ResponseMessage> {
public TransportResponseHandler(Channel channel) {
this.channel = channel;
- this.outstandingFetches = new ConcurrentHashMap<StreamChunkId, ChunkReceivedCallback>();
- this.outstandingRpcs = new ConcurrentHashMap<Long, RpcResponseCallback>();
- this.streamCallbacks = new ConcurrentLinkedQueue<StreamCallback>();
+ this.outstandingFetches = new ConcurrentHashMap<>();
+ this.outstandingRpcs = new ConcurrentHashMap<>();
+ this.streamCallbacks = new ConcurrentLinkedQueue<>();
this.timeOfLastRequestNs = new AtomicLong(0);
}
diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageWithHeader.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageWithHeader.java
index 66227f96a1..4f8781b42a 100644
--- a/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageWithHeader.java
+++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageWithHeader.java
@@ -18,6 +18,7 @@
package org.apache.spark.network.protocol;
import java.io.IOException;
+import java.nio.ByteBuffer;
import java.nio.channels.WritableByteChannel;
import javax.annotation.Nullable;
@@ -44,6 +45,14 @@ class MessageWithHeader extends AbstractReferenceCounted implements FileRegion {
private long totalBytesTransferred;
/**
+ * When the write buffer size is larger than this limit, I/O will be done in chunks of this size.
+ * The size should not be too large, since that wastes the underlying memory copy: e.g. if the
+ * network's available buffer is smaller than this limit, the data cannot be sent in a single
+ * write operation, yet the full-size memory copy is still made.
+ */
+ private static final int NIO_BUFFER_LIMIT = 256 * 1024;
+
+ /**
* Construct a new MessageWithHeader.
*
* @param managedBuffer the {@link ManagedBuffer} that the message body came from. This needs to
@@ -128,8 +137,27 @@ class MessageWithHeader extends AbstractReferenceCounted implements FileRegion {
}
private int copyByteBuf(ByteBuf buf, WritableByteChannel target) throws IOException {
- int written = target.write(buf.nioBuffer());
+ ByteBuffer buffer = buf.nioBuffer();
+ int written = (buffer.remaining() <= NIO_BUFFER_LIMIT) ?
+ target.write(buffer) : writeNioBuffer(target, buffer);
buf.skipBytes(written);
return written;
}
+
+ private int writeNioBuffer(
+ WritableByteChannel writeCh,
+ ByteBuffer buf) throws IOException {
+ int originalLimit = buf.limit();
+ int ret = 0;
+
+ try {
+ int ioSize = Math.min(buf.remaining(), NIO_BUFFER_LIMIT);
+ buf.limit(buf.position() + ioSize);
+ ret = writeCh.write(buf);
+ } finally {
+ buf.limit(originalLimit);
+ }
+
+ return ret;
+ }
}
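The new copyByteBuf()/writeNioBuffer() pair caps a single channel write at NIO_BUFFER_LIMIT bytes by temporarily shrinking the buffer's limit, so a large message body no longer forces one oversized copy inside the channel implementation. A standalone sketch of the same limiting trick (class and method names are made up; the 256 KiB cap mirrors the constant above):

import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.channels.WritableByteChannel;

public final class ChunkedWriteSketch {
  private static final int NIO_BUFFER_LIMIT = 256 * 1024;

  /** Writes at most NIO_BUFFER_LIMIT bytes of buf to the channel and returns the count written. */
  static int writeChunk(WritableByteChannel channel, ByteBuffer buf) throws IOException {
    if (buf.remaining() <= NIO_BUFFER_LIMIT) {
      return channel.write(buf);
    }
    int originalLimit = buf.limit();
    try {
      // Shrink the visible window so the channel copies at most one chunk per call.
      buf.limit(buf.position() + NIO_BUFFER_LIMIT);
      return channel.write(buf);
    } finally {
      buf.limit(originalLimit);
    }
  }
}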
diff --git a/common/network-common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java b/common/network-common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java
index e2222ae085..ae7e520b2f 100644
--- a/common/network-common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java
+++ b/common/network-common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java
@@ -63,7 +63,7 @@ public class OneForOneStreamManager extends StreamManager {
// For debugging purposes, start with a random stream id to help identify different streams.
// This does not need to be globally unique, only unique to this class.
nextStreamId = new AtomicLong((long) new Random().nextInt(Integer.MAX_VALUE) * 1000);
- streams = new ConcurrentHashMap<Long, StreamState>();
+ streams = new ConcurrentHashMap<>();
}
@Override
diff --git a/common/network-common/src/main/java/org/apache/spark/network/util/TransportFrameDecoder.java b/common/network-common/src/main/java/org/apache/spark/network/util/TransportFrameDecoder.java
index bd1830e6ab..fcec7dfd0c 100644
--- a/common/network-common/src/main/java/org/apache/spark/network/util/TransportFrameDecoder.java
+++ b/common/network-common/src/main/java/org/apache/spark/network/util/TransportFrameDecoder.java
@@ -140,7 +140,7 @@ public class TransportFrameDecoder extends ChannelInboundHandlerAdapter {
}
// Otherwise, create a composite buffer.
- CompositeByteBuf frame = buffers.getFirst().alloc().compositeBuffer();
+ CompositeByteBuf frame = buffers.getFirst().alloc().compositeBuffer(Integer.MAX_VALUE);
while (remaining > 0) {
ByteBuf next = nextBufferForFrame(remaining);
remaining -= next.readableBytes();
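For context: compositeBuffer() with no argument uses Netty's default maxNumComponents (16), and once a composite exceeds that count its components are consolidated, i.e. copied into a single buffer. Passing Integer.MAX_VALUE keeps every decoded chunk as a separate, zero-copy component. A small illustration (not Spark code; the loop and chunk sizes are arbitrary):

import io.netty.buffer.ByteBuf;
import io.netty.buffer.CompositeByteBuf;
import io.netty.buffer.Unpooled;

public class CompositeFrameSketch {
  public static void main(String[] args) {
    CompositeByteBuf frame = Unpooled.compositeBuffer(Integer.MAX_VALUE);
    for (int i = 0; i < 64; i++) {
      ByteBuf chunk = Unpooled.wrappedBuffer(new byte[]{(byte) i});
      frame.addComponent(chunk);
      // addComponent does not advance the writer index, so advance it explicitly.
      frame.writerIndex(frame.writerIndex() + chunk.readableBytes());
    }
    // With the default cap of 16 this would have consolidated; here all 64 components survive.
    System.out.println("components kept: " + frame.numComponents());
    frame.release();
  }
}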
diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/sasl/ShuffleSecretManager.java b/common/network-shuffle/src/main/java/org/apache/spark/network/sasl/ShuffleSecretManager.java
index 268cb40121..56a025c4d9 100644
--- a/common/network-shuffle/src/main/java/org/apache/spark/network/sasl/ShuffleSecretManager.java
+++ b/common/network-shuffle/src/main/java/org/apache/spark/network/sasl/ShuffleSecretManager.java
@@ -37,7 +37,7 @@ public class ShuffleSecretManager implements SecretKeyHolder {
private static final String SPARK_SASL_USER = "sparkSaslUser";
public ShuffleSecretManager() {
- shuffleSecretMap = new ConcurrentHashMap<String, String>();
+ shuffleSecretMap = new ConcurrentHashMap<>();
}
/**
diff --git a/common/network-yarn/pom.xml b/common/network-yarn/pom.xml
index 3cb44324f2..bc83ef24c3 100644
--- a/common/network-yarn/pom.xml
+++ b/common/network-yarn/pom.xml
@@ -36,7 +36,7 @@
<!-- Make sure all Hadoop dependencies are provided to avoid repackaging. -->
<hadoop.deps.scope>provided</hadoop.deps.scope>
<shuffle.jar>${project.build.directory}/scala-${scala.binary.version}/spark-${project.version}-yarn-shuffle.jar</shuffle.jar>
- <shade>org/spark-project/</shade>
+ <shade>org/spark_project/</shade>
</properties>
<dependencies>
@@ -91,7 +91,7 @@
<relocations>
<relocation>
<pattern>com.fasterxml.jackson</pattern>
- <shadedPattern>org.spark-project.com.fasterxml.jackson</shadedPattern>
+ <shadedPattern>${spark.shade.packageName}.com.fasterxml.jackson</shadedPattern>
<includes>
<include>com.fasterxml.jackson.**</include>
</includes>
diff --git a/common/network-yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java b/common/network-yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java
index ba6d30a74c..4bc3c1a3c8 100644
--- a/common/network-yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java
+++ b/common/network-yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java
@@ -24,6 +24,7 @@ import java.util.List;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.Lists;
import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.Path;
import org.apache.hadoop.yarn.api.records.ContainerId;
import org.apache.hadoop.yarn.server.api.*;
import org.slf4j.Logger;
@@ -118,7 +119,7 @@ public class YarnShuffleService extends AuxiliaryService {
// an application was stopped while the NM was down, we expect yarn to call stopApplication()
// when it comes back
registeredExecutorFile =
- findRegisteredExecutorFile(conf.getStrings("yarn.nodemanager.local-dirs"));
+ findRegisteredExecutorFile(conf.getTrimmedStrings("yarn.nodemanager.local-dirs"));
TransportConf transportConf = new TransportConf("shuffle", new HadoopConfigProvider(conf));
// If authentication is enabled, set up the shuffle server to use a
@@ -191,12 +192,12 @@ public class YarnShuffleService extends AuxiliaryService {
private File findRegisteredExecutorFile(String[] localDirs) {
for (String dir: localDirs) {
- File f = new File(dir, "registeredExecutors.ldb");
+ File f = new File(new Path(dir).toUri().getPath(), "registeredExecutors.ldb");
if (f.exists()) {
return f;
}
}
- return new File(localDirs[0], "registeredExecutors.ldb");
+ return new File(new Path(localDirs[0]).toUri().getPath(), "registeredExecutors.ldb");
}
/**
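In the hunk above, getTrimmedStrings() strips stray whitespace from the comma-separated yarn.nodemanager.local-dirs value, and wrapping each entry in a Hadoop Path removes any URI scheme before the directory is handed to java.io.File. A small sketch of that normalization, using a made-up directory value and a hypothetical class name:

import java.io.File;
import org.apache.hadoop.fs.Path;

public class LocalDirSketch {
  public static void main(String[] args) {
    // Example value as it might appear in yarn-site.xml, including a file: scheme.
    String dir = "file:///data/yarn/local";
    // new File(dir, ...) would keep "file:" in the path; Path#toUri#getPath yields a plain path.
    File f = new File(new Path(dir).toUri().getPath(), "registeredExecutors.ldb");
    System.out.println(f); // /data/yarn/local/registeredExecutors.ldb
  }
}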
diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java
index 18761bfd22..bdf52f32c6 100644
--- a/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java
+++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java
@@ -17,8 +17,12 @@
package org.apache.spark.unsafe;
+import java.lang.reflect.Constructor;
import java.lang.reflect.Field;
+import java.lang.reflect.Method;
+import java.nio.ByteBuffer;
+import sun.misc.Cleaner;
import sun.misc.Unsafe;
public final class Platform {
@@ -37,6 +41,33 @@ public final class Platform {
public static final int DOUBLE_ARRAY_OFFSET;
+ private static final boolean unaligned;
+ static {
+ boolean _unaligned;
+ // use reflection to access unaligned field
+ try {
+ Class<?> bitsClass =
+ Class.forName("java.nio.Bits", false, ClassLoader.getSystemClassLoader());
+ Method unalignedMethod = bitsClass.getDeclaredMethod("unaligned");
+ unalignedMethod.setAccessible(true);
+ _unaligned = Boolean.TRUE.equals(unalignedMethod.invoke(null));
+ } catch (Throwable t) {
+ // We at least know x86 and x64 support unaligned access.
+ String arch = System.getProperty("os.arch", "");
+ //noinspection DynamicRegexReplaceableByCompiledPattern
+ _unaligned = arch.matches("^(i[3-6]86|x86(_64)?|x64|amd64)$");
+ }
+ unaligned = _unaligned;
+ }
+
+ /**
+ * @return true when the running JVM has sun.misc.Unsafe available and the underlying
+ * system supports unaligned memory access.
+ */
+ public static boolean unaligned() {
+ return unaligned;
+ }
+
public static int getInt(Object object, long offset) {
return _UNSAFE.getInt(object, offset);
}
@@ -116,6 +147,35 @@ public final class Platform {
return newMemory;
}
+ /**
+ * Uses internal JDK APIs to allocate a DirectByteBuffer while ignoring the JVM's
+ * MaxDirectMemorySize limit (the default limit is too low and we do not want to require users
+ * to increase it).
+ */
+ @SuppressWarnings("unchecked")
+ public static ByteBuffer allocateDirectBuffer(int size) {
+ try {
+ Class cls = Class.forName("java.nio.DirectByteBuffer");
+ Constructor constructor = cls.getDeclaredConstructor(Long.TYPE, Integer.TYPE);
+ constructor.setAccessible(true);
+ Field cleanerField = cls.getDeclaredField("cleaner");
+ cleanerField.setAccessible(true);
+ final long memory = allocateMemory(size);
+ ByteBuffer buffer = (ByteBuffer) constructor.newInstance(memory, size);
+ Cleaner cleaner = Cleaner.create(buffer, new Runnable() {
+ @Override
+ public void run() {
+ freeMemory(memory);
+ }
+ });
+ cleanerField.set(buffer, cleaner);
+ return buffer;
+ } catch (Exception e) {
+ throwException(e);
+ }
+ throw new IllegalStateException("unreachable");
+ }
+
public static void setMemory(long address, byte value, long size) {
_UNSAFE.setMemory(address, size, value);
}
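The unaligned flag above is probed once through the package-private java.nio.Bits#unaligned() and falls back to an architecture allowlist if reflection fails, while allocateDirectBuffer() reflects on java.nio.DirectByteBuffer to sidestep -XX:MaxDirectMemorySize. A standalone sketch of just the alignment probe, mirroring the static initializer (class name is made up):

import java.lang.reflect.Method;

public final class UnalignedProbe {
  public static boolean unaligned() {
    try {
      Class<?> bits = Class.forName("java.nio.Bits", false, ClassLoader.getSystemClassLoader());
      Method m = bits.getDeclaredMethod("unaligned");
      m.setAccessible(true);
      return Boolean.TRUE.equals(m.invoke(null));
    } catch (Throwable t) {
      // Fall back to architectures known to tolerate unaligned access.
      String arch = System.getProperty("os.arch", "");
      return arch.matches("^(i[3-6]86|x86(_64)?|x64|amd64)$");
    }
  }

  public static void main(String[] args) {
    System.out.println("unaligned access supported: " + unaligned());
  }
}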
diff --git a/conf/log4j.properties.template b/conf/log4j.properties.template
index 9809b0c828..ec1aa187df 100644
--- a/conf/log4j.properties.template
+++ b/conf/log4j.properties.template
@@ -28,8 +28,8 @@ log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss} %p %c{1}:
log4j.logger.org.apache.spark.repl.Main=WARN
# Settings to quiet third party logs that are too verbose
-log4j.logger.org.spark-project.jetty=WARN
-log4j.logger.org.spark-project.jetty.util.component.AbstractLifeCycle=ERROR
+log4j.logger.org.spark_project.jetty=WARN
+log4j.logger.org.spark_project.jetty.util.component.AbstractLifeCycle=ERROR
log4j.logger.org.apache.spark.repl.SparkIMain$exprTyper=INFO
log4j.logger.org.apache.spark.repl.SparkILoop$SparkILoopInterpreter=INFO
log4j.logger.org.apache.parquet=ERROR
diff --git a/core/src/main/java/org/apache/spark/JavaSparkListener.java b/core/src/main/java/org/apache/spark/JavaSparkListener.java
deleted file mode 100644
index 23bc9a2e81..0000000000
--- a/core/src/main/java/org/apache/spark/JavaSparkListener.java
+++ /dev/null
@@ -1,88 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark;
-
-import org.apache.spark.scheduler.*;
-
-/**
- * Java clients should extend this class instead of implementing
- * SparkListener directly. This is to prevent java clients
- * from breaking when new events are added to the SparkListener
- * trait.
- *
- * This is a concrete class instead of abstract to enforce
- * new events get added to both the SparkListener and this adapter
- * in lockstep.
- */
-public class JavaSparkListener implements SparkListener {
-
- @Override
- public void onStageCompleted(SparkListenerStageCompleted stageCompleted) { }
-
- @Override
- public void onStageSubmitted(SparkListenerStageSubmitted stageSubmitted) { }
-
- @Override
- public void onTaskStart(SparkListenerTaskStart taskStart) { }
-
- @Override
- public void onTaskGettingResult(SparkListenerTaskGettingResult taskGettingResult) { }
-
- @Override
- public void onTaskEnd(SparkListenerTaskEnd taskEnd) { }
-
- @Override
- public void onJobStart(SparkListenerJobStart jobStart) { }
-
- @Override
- public void onJobEnd(SparkListenerJobEnd jobEnd) { }
-
- @Override
- public void onEnvironmentUpdate(SparkListenerEnvironmentUpdate environmentUpdate) { }
-
- @Override
- public void onBlockManagerAdded(SparkListenerBlockManagerAdded blockManagerAdded) { }
-
- @Override
- public void onBlockManagerRemoved(SparkListenerBlockManagerRemoved blockManagerRemoved) { }
-
- @Override
- public void onUnpersistRDD(SparkListenerUnpersistRDD unpersistRDD) { }
-
- @Override
- public void onApplicationStart(SparkListenerApplicationStart applicationStart) { }
-
- @Override
- public void onApplicationEnd(SparkListenerApplicationEnd applicationEnd) { }
-
- @Override
- public void onExecutorMetricsUpdate(SparkListenerExecutorMetricsUpdate executorMetricsUpdate) { }
-
- @Override
- public void onExecutorAdded(SparkListenerExecutorAdded executorAdded) { }
-
- @Override
- public void onExecutorRemoved(SparkListenerExecutorRemoved executorRemoved) { }
-
- @Override
- public void onBlockUpdated(SparkListenerBlockUpdated blockUpdated) { }
-
- @Override
- public void onOtherEvent(SparkListenerEvent event) { }
-
-}
diff --git a/core/src/main/java/org/apache/spark/SparkExecutorInfo.java b/core/src/main/java/org/apache/spark/SparkExecutorInfo.java
new file mode 100644
index 0000000000..dc3e826475
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/SparkExecutorInfo.java
@@ -0,0 +1,33 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark;
+
+import java.io.Serializable;
+
+/**
+ * Exposes information about Spark Executors.
+ *
+ * This interface is not designed to be implemented outside of Spark. We may add additional methods
+ * which may break binary compatibility with outside implementations.
+ */
+public interface SparkExecutorInfo extends Serializable {
+ String host();
+ int port();
+ long cacheSize();
+ int numRunningTasks();
+}
diff --git a/core/src/main/java/org/apache/spark/SparkFirehoseListener.java b/core/src/main/java/org/apache/spark/SparkFirehoseListener.java
index e6b24afd88..97eed611e8 100644
--- a/core/src/main/java/org/apache/spark/SparkFirehoseListener.java
+++ b/core/src/main/java/org/apache/spark/SparkFirehoseListener.java
@@ -28,7 +28,7 @@ import org.apache.spark.scheduler.*;
* this was a concrete Scala class, default implementations of new event handlers would be inherited
* from the SparkListener trait).
*/
-public class SparkFirehoseListener implements SparkListener {
+public class SparkFirehoseListener implements SparkListenerInterface {
public void onEvent(SparkListenerEvent event) { }
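With JavaSparkListener removed and SparkFirehoseListener now implementing SparkListenerInterface, Java code that only cares about a handful of events can extend SparkFirehoseListener and dispatch on the event type in onEvent(). A hedged sketch (the listener class and the job-end handling are illustrative, not part of this patch):

import org.apache.spark.SparkFirehoseListener;
import org.apache.spark.scheduler.SparkListenerEvent;
import org.apache.spark.scheduler.SparkListenerJobEnd;

public class JobEndLogger extends SparkFirehoseListener {
  @Override
  public void onEvent(SparkListenerEvent event) {
    // Every listener callback funnels through onEvent, so a single override also
    // covers event types added in future Spark versions.
    if (event instanceof SparkListenerJobEnd) {
      SparkListenerJobEnd jobEnd = (SparkListenerJobEnd) event;
      System.out.println("Job " + jobEnd.jobId() + " finished at " + jobEnd.time());
    }
  }
}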
diff --git a/core/src/main/java/org/apache/spark/api/java/StorageLevels.java b/core/src/main/java/org/apache/spark/api/java/StorageLevels.java
index 666c797738..3fcb52f615 100644
--- a/core/src/main/java/org/apache/spark/api/java/StorageLevels.java
+++ b/core/src/main/java/org/apache/spark/api/java/StorageLevels.java
@@ -34,13 +34,13 @@ public class StorageLevels {
public static final StorageLevel MEMORY_AND_DISK_2 = create(true, true, false, true, 2);
public static final StorageLevel MEMORY_AND_DISK_SER = create(true, true, false, false, 1);
public static final StorageLevel MEMORY_AND_DISK_SER_2 = create(true, true, false, false, 2);
- public static final StorageLevel OFF_HEAP = create(false, false, true, false, 1);
+ public static final StorageLevel OFF_HEAP = create(true, true, true, false, 1);
/**
* Create a new StorageLevel object.
* @param useDisk saved to disk, if true
- * @param useMemory saved to memory, if true
- * @param useOffHeap saved to Tachyon, if true
+ * @param useMemory saved to on-heap memory, if true
+ * @param useOffHeap saved to off-heap memory, if true
* @param deserialized saved as deserialized objects, if true
* @param replication replication factor
*/
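With these flag changes OFF_HEAP is no longer a Tachyon-only level: it now permits disk and on-heap memory as well, with off-heap storage selected by the third flag. A minimal sketch reproducing the level through the factory documented above (class name is made up):

import org.apache.spark.api.java.StorageLevels;
import org.apache.spark.storage.StorageLevel;

public class OffHeapLevelSketch {
  public static void main(String[] args) {
    // create(useDisk, useMemory, useOffHeap, deserialized, replication)
    StorageLevel offHeap = StorageLevels.create(true, true, true, false, 1);
    System.out.println(offHeap); // same flags as the redefined StorageLevels.OFF_HEAP
  }
}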
diff --git a/core/src/main/scala/org/apache/spark/io/LZ4BlockInputStream.java b/core/src/main/java/org/apache/spark/io/LZ4BlockInputStream.java
index 27b6f0d4a3..8783b5f56e 100644
--- a/core/src/main/scala/org/apache/spark/io/LZ4BlockInputStream.java
+++ b/core/src/main/java/org/apache/spark/io/LZ4BlockInputStream.java
@@ -20,20 +20,17 @@ import java.io.IOException;
import java.io.InputStream;
import java.util.zip.Checksum;
-import net.jpountz.lz4.LZ4BlockOutputStream;
import net.jpountz.lz4.LZ4Exception;
import net.jpountz.lz4.LZ4Factory;
import net.jpountz.lz4.LZ4FastDecompressor;
import net.jpountz.util.SafeUtils;
-import net.jpountz.xxhash.StreamingXXHash32;
-import net.jpountz.xxhash.XXHash32;
import net.jpountz.xxhash.XXHashFactory;
/**
* {@link InputStream} implementation to decode data written with
- * {@link LZ4BlockOutputStream}. This class is not thread-safe and does not
+ * {@link net.jpountz.lz4.LZ4BlockOutputStream}. This class is not thread-safe and does not
* support {@link #mark(int)}/{@link #reset()}.
- * @see LZ4BlockOutputStream
+ * @see net.jpountz.lz4.LZ4BlockOutputStream
*
* This is based on net.jpountz.lz4.LZ4BlockInputStream
*
@@ -90,12 +87,13 @@ public final class LZ4BlockInputStream extends FilterInputStream {
}
/**
- * Create a new instance using {@link XXHash32} for checksuming.
+ * Create a new instance using {@link net.jpountz.xxhash.XXHash32} for checksumming.
* @see #LZ4BlockInputStream(InputStream, LZ4FastDecompressor, Checksum)
- * @see StreamingXXHash32#asChecksum()
+ * @see net.jpountz.xxhash.StreamingXXHash32#asChecksum()
*/
public LZ4BlockInputStream(InputStream in, LZ4FastDecompressor decompressor) {
- this(in, decompressor, XXHashFactory.fastestInstance().newStreamingHash32(DEFAULT_SEED).asChecksum());
+ this(in, decompressor,
+ XXHashFactory.fastestInstance().newStreamingHash32(DEFAULT_SEED).asChecksum());
}
/**
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java
index 81ee7ab58a..3c2980e442 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java
@@ -215,8 +215,6 @@ final class ShuffleExternalSorter extends MemoryConsumer {
}
}
- inMemSorter.reset();
-
if (!isLastFile) { // i.e. this is a spill file
// The current semantics of `shuffleRecordsWritten` seem to be that it's updated when records
// are written to disk, not when they enter the shuffle sorting code. DiskBlockObjectWriter
@@ -255,6 +253,10 @@ final class ShuffleExternalSorter extends MemoryConsumer {
writeSortedFile(false);
final long spillSize = freeMemory();
+ inMemSorter.reset();
+ // Reset the in-memory sorter's pointer array only after freeing up the memory pages holding the
+ // records. Otherwise, if the task's memory is over-allocated, we might not be able to get memory
+ // for the pointer array without first freeing those pages.
taskContext.taskMetrics().incMemoryBytesSpilled(spillSize);
return spillSize;
}
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java
index fe79ff0e30..76b0e6a304 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java
@@ -51,9 +51,12 @@ final class ShuffleInMemorySorter {
*/
private int pos = 0;
+ private int initialSize;
+
ShuffleInMemorySorter(MemoryConsumer consumer, int initialSize) {
this.consumer = consumer;
assert (initialSize > 0);
+ this.initialSize = initialSize;
this.array = consumer.allocateArray(initialSize);
this.sorter = new Sorter<>(ShuffleSortDataFormat.INSTANCE);
}
@@ -70,6 +73,10 @@ final class ShuffleInMemorySorter {
}
public void reset() {
+ if (consumer != null) {
+ consumer.freeArray(array);
+ this.array = consumer.allocateArray(initialSize);
+ }
pos = 0;
}
diff --git a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
index 9aacb084f6..6807710f9f 100644
--- a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
+++ b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
@@ -56,9 +56,10 @@ import org.apache.spark.util.collection.unsafe.sort.UnsafeSorterSpillWriter;
* Bytes 4 to 8: len(k)
* Bytes 8 to 8 + len(k): key data
* Bytes 8 + len(k) to 8 + len(k) + len(v): value data
+ * Bytes 8 + len(k) + len(v) to 8 + len(k) + len(v) + 8: pointer to next pair
*
* This means that the first four bytes store the entire record (key + value) length. This format
- * is consistent with {@link org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter},
+ * is compatible with {@link org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter},
* so we can pass records from this map directly into the sorter to sort records in place.
*/
public final class BytesToBytesMap extends MemoryConsumer {
@@ -132,7 +133,12 @@ public final class BytesToBytesMap extends MemoryConsumer {
/**
* Number of keys defined in the map.
*/
- private int numElements;
+ private int numKeys;
+
+ /**
+ * Number of values defined in the map. A key could have multiple values.
+ */
+ private int numValues;
/**
* The map will be expanded once the number of keys exceeds this threshold.
@@ -223,7 +229,12 @@ public final class BytesToBytesMap extends MemoryConsumer {
/**
* Returns the number of keys defined in the map.
*/
- public int numElements() { return numElements; }
+ public int numKeys() { return numKeys; }
+
+ /**
+ * Returns the number of values defined in the map. A key could have multiple values.
+ */
+ public int numValues() { return numValues; }
public final class MapIterator implements Iterator<Location> {
@@ -311,7 +322,8 @@ public final class BytesToBytesMap extends MemoryConsumer {
if (currentPage != null) {
int totalLength = Platform.getInt(pageBaseObject, offsetInPage);
loc.with(currentPage, offsetInPage);
- offsetInPage += 4 + totalLength;
+ // [total size] [key size] [key] [value] [pointer to next]
+ offsetInPage += 4 + totalLength + 8;
recordsInPage --;
return loc;
} else {
@@ -361,7 +373,7 @@ public final class BytesToBytesMap extends MemoryConsumer {
while (numRecords > 0) {
int length = Platform.getInt(base, offset);
writer.write(base, offset + 4, length, 0);
- offset += 4 + length;
+ offset += 4 + length + 8;
numRecords--;
}
writer.close();
@@ -395,7 +407,7 @@ public final class BytesToBytesMap extends MemoryConsumer {
* `lookup()`, the behavior of the returned iterator is undefined.
*/
public MapIterator iterator() {
- return new MapIterator(numElements, loc, false);
+ return new MapIterator(numValues, loc, false);
}
/**
@@ -409,7 +421,7 @@ public final class BytesToBytesMap extends MemoryConsumer {
* `lookup()`, the behavior of the returned iterator is undefined.
*/
public MapIterator destructiveIterator() {
- return new MapIterator(numElements, loc, true);
+ return new MapIterator(numValues, loc, true);
}
/**
@@ -560,6 +572,20 @@ public final class BytesToBytesMap extends MemoryConsumer {
}
/**
+ * Find the next pair that has the same key as current one.
+ */
+ public boolean nextValue() {
+ assert isDefined;
+ long nextAddr = Platform.getLong(baseObject, valueOffset + valueLength);
+ if (nextAddr == 0) {
+ return false;
+ } else {
+ updateAddressesAndSizes(nextAddr);
+ return true;
+ }
+ }
+
+ /**
* Returns the memory page that contains the current record.
* This is only valid if this is returned by {@link BytesToBytesMap#iterator()}.
*/
@@ -625,10 +651,9 @@ public final class BytesToBytesMap extends MemoryConsumer {
}
/**
- * Store a new key and value. This method may only be called once for a given key; if you want
- * to update the value associated with a key, then you can directly manipulate the bytes stored
- * at the value address. The return value indicates whether the put succeeded or whether it
- * failed because additional memory could not be acquired.
+ * Append a new value for the key. This method could be called multiple times for a given key.
+ * The return value indicates whether the put succeeded or whether it failed because additional
+ * memory could not be acquired.
* <p>
* It is only valid to call this method immediately after calling `lookup()` using the same key.
* </p>
@@ -637,7 +662,7 @@ public final class BytesToBytesMap extends MemoryConsumer {
* </p>
* <p>
* After calling this method, calls to `get[Key|Value]Address()` and `get[Key|Value]Length`
- * will return information on the data stored by this `putNewKey` call.
+ * will return information on the data stored by this `append` call.
* </p>
* <p>
* As an example usage, here's the proper way to store a new key:
@@ -645,7 +670,7 @@ public final class BytesToBytesMap extends MemoryConsumer {
* <pre>
* Location loc = map.lookup(keyBase, keyOffset, keyLength);
* if (!loc.isDefined()) {
- * if (!loc.putNewKey(keyBase, keyOffset, keyLength, ...)) {
+ * if (!loc.append(keyBase, keyOffset, keyLength, ...)) {
* // handle failure to grow map (by spilling, for example)
* }
* }
@@ -657,26 +682,23 @@ public final class BytesToBytesMap extends MemoryConsumer {
* @return true if the put() was successful and false if the put() failed because memory could
* not be acquired.
*/
- public boolean putNewKey(Object keyBase, long keyOffset, int keyLength,
- Object valueBase, long valueOffset, int valueLength) {
- assert (!isDefined) : "Can only set value once for a key";
- assert (keyLength % 8 == 0);
- assert (valueLength % 8 == 0);
- assert(longArray != null);
+ public boolean append(Object kbase, long koff, int klen, Object vbase, long voff, int vlen) {
+ assert (klen % 8 == 0);
+ assert (vlen % 8 == 0);
+ assert (longArray != null);
-
- if (numElements == MAX_CAPACITY
+ if (numKeys == MAX_CAPACITY
// The map could be reused from the last spill (because there was not enough memory to grow),
// in which case we don't try to grow again if we hit the `growthThreshold`.
- || !canGrowArray && numElements > growthThreshold) {
+ || !canGrowArray && numKeys > growthThreshold) {
return false;
}
// Here, we'll copy the data into our data pages. Because we only store a relative offset from
// the key address instead of storing the absolute address of the value, the key and value
// must be stored in the same memory page.
- // (8 byte key length) (key) (value)
- final long recordLength = 8 + keyLength + valueLength;
+ // (8 byte key length) (key) (value) (8 byte pointer to next value)
+ final long recordLength = 8 + klen + vlen + 8;
if (currentPage == null || currentPage.size() - pageCursor < recordLength) {
if (!acquireNewPage(recordLength + 4L)) {
return false;
@@ -687,30 +709,36 @@ public final class BytesToBytesMap extends MemoryConsumer {
final Object base = currentPage.getBaseObject();
long offset = currentPage.getBaseOffset() + pageCursor;
final long recordOffset = offset;
- Platform.putInt(base, offset, keyLength + valueLength + 4);
- Platform.putInt(base, offset + 4, keyLength);
+ Platform.putInt(base, offset, klen + vlen + 4);
+ Platform.putInt(base, offset + 4, klen);
offset += 8;
- Platform.copyMemory(keyBase, keyOffset, base, offset, keyLength);
- offset += keyLength;
- Platform.copyMemory(valueBase, valueOffset, base, offset, valueLength);
+ Platform.copyMemory(kbase, koff, base, offset, klen);
+ offset += klen;
+ Platform.copyMemory(vbase, voff, base, offset, vlen);
+ offset += vlen;
+ // put this value at the beginning of the list
+ Platform.putLong(base, offset, isDefined ? longArray.get(pos * 2) : 0);
// --- Update bookkeeping data structures ----------------------------------------------------
offset = currentPage.getBaseOffset();
Platform.putInt(base, offset, Platform.getInt(base, offset) + 1);
pageCursor += recordLength;
- numElements++;
final long storedKeyAddress = taskMemoryManager.encodePageNumberAndOffset(
currentPage, recordOffset);
longArray.set(pos * 2, storedKeyAddress);
- longArray.set(pos * 2 + 1, keyHashcode);
updateAddressesAndSizes(storedKeyAddress);
- isDefined = true;
+ numValues++;
+ if (!isDefined) {
+ numKeys++;
+ longArray.set(pos * 2 + 1, keyHashcode);
+ isDefined = true;
- if (numElements > growthThreshold && longArray.size() < MAX_CAPACITY) {
- try {
- growAndRehash();
- } catch (OutOfMemoryError oom) {
- canGrowArray = false;
+ if (numKeys > growthThreshold && longArray.size() < MAX_CAPACITY) {
+ try {
+ growAndRehash();
+ } catch (OutOfMemoryError oom) {
+ canGrowArray = false;
+ }
}
}
return true;
@@ -866,7 +894,8 @@ public final class BytesToBytesMap extends MemoryConsumer {
* Reset this map to initialized state.
*/
public void reset() {
- numElements = 0;
+ numKeys = 0;
+ numValues = 0;
longArray.zeroOut();
while (dataPages.size() > 0) {
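Since append() may now be called repeatedly for one key, each record carries a trailing 8-byte pointer and nextValue() walks that per-key chain. A hedged usage sketch along the lines of the javadoc example above (the wrapper class and appendAndCount() are illustrative; key and value lengths must stay 8-byte aligned, as the asserts require):

import org.apache.spark.unsafe.Platform;
import org.apache.spark.unsafe.map.BytesToBytesMap;

public class MultiValueMapSketch {
  /** Appends one more value for the given key, then counts how many values the key now holds. */
  static void appendAndCount(BytesToBytesMap map, long[] key, long[] value) {
    BytesToBytesMap.Location loc =
        map.lookup(key, Platform.LONG_ARRAY_OFFSET, key.length * 8);
    if (!loc.append(key, Platform.LONG_ARRAY_OFFSET, key.length * 8,
                    value, Platform.LONG_ARRAY_OFFSET, value.length * 8)) {
      throw new IllegalStateException("could not acquire memory for the new entry");
    }
    int values = 1;
    while (loc.nextValue()) {  // follow the chain of earlier values for this key
      values++;
    }
    System.out.println("values stored for key: " + values);
  }
}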
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
index cf7c9a299f..dc9a8db9c5 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
@@ -200,14 +200,17 @@ public final class UnsafeExternalSorter extends MemoryConsumer {
spillWriter.write(baseObject, baseOffset, recordLength, sortedRecords.getKeyPrefix());
}
spillWriter.close();
-
- inMemSorter.reset();
}
final long spillSize = freeMemory();
// Note that this is more-or-less going to be a multiple of the page size, so wasted space in
// pages will currently be counted as memory spilled even though that space isn't actually
// written to disk. This also counts the space needed to store the sorter's pointer array.
+ inMemSorter.reset();
+ // Reset the in-memory sorter's pointer array only after freeing up the memory pages holding the
+ // records. Otherwise, if the task's memory is over-allocated, we might not be able to get memory
+ // for the pointer array without first freeing those pages.
+
taskContext.taskMetrics().incMemoryBytesSpilled(spillSize);
return spillSize;
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
index 145c3a1950..01eae0e8dc 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
@@ -84,6 +84,8 @@ public final class UnsafeInMemorySorter {
*/
private int pos = 0;
+ private long initialSize;
+
public UnsafeInMemorySorter(
final MemoryConsumer consumer,
final TaskMemoryManager memoryManager,
@@ -102,6 +104,7 @@ public final class UnsafeInMemorySorter {
LongArray array) {
this.consumer = consumer;
this.memoryManager = memoryManager;
+ this.initialSize = array.size();
if (recordComparator != null) {
this.sorter = new Sorter<>(UnsafeSortDataFormat.INSTANCE);
this.sortComparator = new SortComparator(recordComparator, prefixComparator, memoryManager);
@@ -123,6 +126,10 @@ public final class UnsafeInMemorySorter {
}
public void reset() {
+ if (consumer != null) {
+ consumer.freeArray(array);
+ this.array = consumer.allocateArray(initialSize);
+ }
pos = 0;
}
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillMerger.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillMerger.java
index 2b1c860e55..01aed95878 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillMerger.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillMerger.java
@@ -45,7 +45,7 @@ final class UnsafeSorterSpillMerger {
}
}
};
- priorityQueue = new PriorityQueue<UnsafeSorterIterator>(numSpills, comparator);
+ priorityQueue = new PriorityQueue<>(numSpills, comparator);
}
/**
diff --git a/core/src/main/resources/org/apache/spark/log4j-defaults.properties b/core/src/main/resources/org/apache/spark/log4j-defaults.properties
index 0750488e4a..89a7963a86 100644
--- a/core/src/main/resources/org/apache/spark/log4j-defaults.properties
+++ b/core/src/main/resources/org/apache/spark/log4j-defaults.properties
@@ -28,8 +28,8 @@ log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss} %p %c{1}:
log4j.logger.org.apache.spark.repl.Main=WARN
# Settings to quiet third party logs that are too verbose
-log4j.logger.org.spark-project.jetty=WARN
-log4j.logger.org.spark-project.jetty.util.component.AbstractLifeCycle=ERROR
+log4j.logger.org.spark_project.jetty=WARN
+log4j.logger.org.spark_project.jetty.util.component.AbstractLifeCycle=ERROR
log4j.logger.org.apache.spark.repl.SparkIMain$exprTyper=INFO
log4j.logger.org.apache.spark.repl.SparkILoop$SparkILoopInterpreter=INFO
diff --git a/core/src/main/resources/org/apache/spark/ui/static/webui.css b/core/src/main/resources/org/apache/spark/ui/static/webui.css
index 48f86d1536..47dd9162a1 100644
--- a/core/src/main/resources/org/apache/spark/ui/static/webui.css
+++ b/core/src/main/resources/org/apache/spark/ui/static/webui.css
@@ -106,21 +106,22 @@ pre {
line-height: 18px;
padding: 6px;
margin: 0;
+ word-break: break-word;
border-radius: 3px;
}
.stage-details {
- max-height: 100px;
overflow-y: auto;
margin: 0;
+ display: block;
transition: max-height 0.25s ease-out, padding 0.25s ease-out;
}
.stage-details.collapsed {
- max-height: 0;
padding-top: 0;
padding-bottom: 0;
border: none;
+ display: none;
}
.description-input {
@@ -143,14 +144,15 @@ pre {
max-height: 300px;
overflow-y: auto;
margin: 0;
+ display: block;
transition: max-height 0.25s ease-out, padding 0.25s ease-out;
}
.stacktrace-details.collapsed {
- max-height: 0;
padding-top: 0;
padding-bottom: 0;
border: none;
+ display: none;
}
span.expand-additional-metrics, span.expand-dag-viz {
diff --git a/core/src/main/scala/org/apache/spark/ContextCleaner.scala b/core/src/main/scala/org/apache/spark/ContextCleaner.scala
index 8fc657c5eb..76692ccec8 100644
--- a/core/src/main/scala/org/apache/spark/ContextCleaner.scala
+++ b/core/src/main/scala/org/apache/spark/ContextCleaner.scala
@@ -278,9 +278,9 @@ private object ContextCleaner {
* Listener class used for testing when any item has been cleaned by the Cleaner class.
*/
private[spark] trait CleanerListener {
- def rddCleaned(rddId: Int)
- def shuffleCleaned(shuffleId: Int)
- def broadcastCleaned(broadcastId: Long)
- def accumCleaned(accId: Long)
- def checkpointCleaned(rddId: Long)
+ def rddCleaned(rddId: Int): Unit
+ def shuffleCleaned(shuffleId: Int): Unit
+ def broadcastCleaned(broadcastId: Long): Unit
+ def accumCleaned(accId: Long): Unit
+ def checkpointCleaned(rddId: Long): Unit
}
diff --git a/core/src/main/scala/org/apache/spark/ExecutorAllocationClient.scala b/core/src/main/scala/org/apache/spark/ExecutorAllocationClient.scala
index 842bfdbadc..8baddf45bf 100644
--- a/core/src/main/scala/org/apache/spark/ExecutorAllocationClient.scala
+++ b/core/src/main/scala/org/apache/spark/ExecutorAllocationClient.scala
@@ -23,6 +23,10 @@ package org.apache.spark
*/
private[spark] trait ExecutorAllocationClient {
+
+ /** Get the list of currently active executors */
+ private[spark] def getExecutorIds(): Seq[String]
+
/**
* Update the cluster manager on our scheduling needs. Three bits of information are included
* to help it make decisions.
diff --git a/core/src/main/scala/org/apache/spark/FutureAction.scala b/core/src/main/scala/org/apache/spark/FutureAction.scala
index 509f5082f1..882d2b21cf 100644
--- a/core/src/main/scala/org/apache/spark/FutureAction.scala
+++ b/core/src/main/scala/org/apache/spark/FutureAction.scala
@@ -41,7 +41,7 @@ trait FutureAction[T] extends Future[T] {
/**
* Cancels the execution of this action.
*/
- def cancel()
+ def cancel(): Unit
/**
* Blocks until this action completes.
@@ -65,7 +65,7 @@ trait FutureAction[T] extends Future[T] {
* When this action is completed, either through an exception, or a value, applies the provided
* function.
*/
- def onComplete[U](func: (Try[T]) => U)(implicit executor: ExecutionContext)
+ def onComplete[U](func: (Try[T]) => U)(implicit executor: ExecutionContext): Unit
/**
* Returns whether the action has already been completed with a value or an exception.
@@ -156,16 +156,16 @@ class SimpleFutureAction[T] private[spark](jobWaiter: JobWaiter[_], resultFunc:
/**
- * Handle via which a "run" function passed to a [[ComplexFutureAction]]
- * can submit jobs for execution.
- */
+ * Handle via which a "run" function passed to a [[ComplexFutureAction]]
+ * can submit jobs for execution.
+ */
@DeveloperApi
trait JobSubmitter {
/**
- * Submit a job for execution and return a FutureAction holding the result.
- * This is a wrapper around the same functionality provided by SparkContext
- * to enable cancellation.
- */
+ * Submit a job for execution and return a FutureAction holding the result.
+ * This is a wrapper around the same functionality provided by SparkContext
+ * to enable cancellation.
+ */
def submitJob[T, U, R](
rdd: RDD[T],
processPartition: Iterator[T] => U,
diff --git a/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala b/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala
index e8748dd80a..2bdbd3fae9 100644
--- a/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala
+++ b/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala
@@ -56,7 +56,7 @@ private[spark] case class HeartbeatResponse(reregisterBlockManager: Boolean)
* Lives in the driver to receive heartbeats from executors.
*/
private[spark] class HeartbeatReceiver(sc: SparkContext, clock: Clock)
- extends ThreadSafeRpcEndpoint with SparkListener with Logging {
+ extends SparkListener with ThreadSafeRpcEndpoint with Logging {
def this(sc: SparkContext) {
this(sc, new SystemClock)
@@ -220,6 +220,7 @@ private[spark] class HeartbeatReceiver(sc: SparkContext, clock: Clock)
}
}
-object HeartbeatReceiver {
+
+private[spark] object HeartbeatReceiver {
val ENDPOINT_NAME = "HeartbeatReceiver"
}
diff --git a/core/src/main/scala/org/apache/spark/HttpServer.scala b/core/src/main/scala/org/apache/spark/HttpServer.scala
index 9fad1f6786..982b6d6b61 100644
--- a/core/src/main/scala/org/apache/spark/HttpServer.scala
+++ b/core/src/main/scala/org/apache/spark/HttpServer.scala
@@ -25,6 +25,7 @@ import org.eclipse.jetty.server.Server
import org.eclipse.jetty.server.bio.SocketConnector
import org.eclipse.jetty.server.ssl.SslSocketConnector
import org.eclipse.jetty.servlet.{DefaultServlet, ServletContextHandler, ServletHolder}
+import org.eclipse.jetty.util.component.LifeCycle
import org.eclipse.jetty.util.security.{Constraint, Password}
import org.eclipse.jetty.util.thread.QueuedThreadPool
@@ -155,6 +156,12 @@ private[spark] class HttpServer(
throw new ServerStateException("Server is already stopped")
} else {
server.stop()
+ // Stop the ThreadPool if it supports the stop() method (through LifeCycle).
+ // This is needed because stopping the Server won't stop the ThreadPool it uses.
+ val threadPool = server.getThreadPool
+ if (threadPool != null && threadPool.isInstanceOf[LifeCycle]) {
+ threadPool.asInstanceOf[LifeCycle].stop
+ }
port = -1
server = null
}
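As the new comment notes, Server.stop() does not stop the thread pool the server was given, so the pool is stopped explicitly when it exposes Jetty's LifeCycle interface. The same check written as plain Java against the Jetty API used here (helper class and method are made up):

import org.eclipse.jetty.server.Server;
import org.eclipse.jetty.util.component.LifeCycle;

public class StopThreadPoolSketch {
  static void stopServerAndPool(Server server) throws Exception {
    server.stop();
    // Stopping the Server leaves its ThreadPool running; stop it too if it supports LifeCycle.
    Object threadPool = server.getThreadPool();
    if (threadPool instanceof LifeCycle) {
      ((LifeCycle) threadPool).stop();
    }
  }
}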
diff --git a/core/src/main/scala/org/apache/spark/InternalAccumulator.scala b/core/src/main/scala/org/apache/spark/InternalAccumulator.scala
index 7aa9057858..0dd4ec656f 100644
--- a/core/src/main/scala/org/apache/spark/InternalAccumulator.scala
+++ b/core/src/main/scala/org/apache/spark/InternalAccumulator.scala
@@ -187,7 +187,7 @@ private[spark] object InternalAccumulator {
* add to the same set of accumulators. We do this to report the distribution of accumulator
* values across all tasks within each stage.
*/
- def create(sc: SparkContext): Seq[Accumulator[_]] = {
+ def createAll(sc: SparkContext): Seq[Accumulator[_]] = {
val accums = createAll()
accums.foreach { accum =>
Accumulators.register(accum)
diff --git a/core/src/main/scala/org/apache/spark/SSLOptions.scala b/core/src/main/scala/org/apache/spark/SSLOptions.scala
index 30db6ccbf4..719905a2c9 100644
--- a/core/src/main/scala/org/apache/spark/SSLOptions.scala
+++ b/core/src/main/scala/org/apache/spark/SSLOptions.scala
@@ -132,34 +132,35 @@ private[spark] case class SSLOptions(
private[spark] object SSLOptions extends Logging {
- /** Resolves SSLOptions settings from a given Spark configuration object at a given namespace.
- *
- * The following settings are allowed:
- * $ - `[ns].enabled` - `true` or `false`, to enable or disable SSL respectively
- * $ - `[ns].keyStore` - a path to the key-store file; can be relative to the current directory
- * $ - `[ns].keyStorePassword` - a password to the key-store file
- * $ - `[ns].keyPassword` - a password to the private key
- * $ - `[ns].keyStoreType` - the type of the key-store
- * $ - `[ns].needClientAuth` - whether SSL needs client authentication
- * $ - `[ns].trustStore` - a path to the trust-store file; can be relative to the current
- * directory
- * $ - `[ns].trustStorePassword` - a password to the trust-store file
- * $ - `[ns].trustStoreType` - the type of trust-store
- * $ - `[ns].protocol` - a protocol name supported by a particular Java version
- * $ - `[ns].enabledAlgorithms` - a comma separated list of ciphers
- *
- * For a list of protocols and ciphers supported by particular Java versions, you may go to
- * [[https://blogs.oracle.com/java-platform-group/entry/diagnosing_tls_ssl_and_https Oracle
- * blog page]].
- *
- * You can optionally specify the default configuration. If you do, for each setting which is
- * missing in SparkConf, the corresponding setting is used from the default configuration.
- *
- * @param conf Spark configuration object where the settings are collected from
- * @param ns the namespace name
- * @param defaults the default configuration
- * @return [[org.apache.spark.SSLOptions]] object
- */
+ /**
+ * Resolves SSLOptions settings from a given Spark configuration object at a given namespace.
+ *
+ * The following settings are allowed:
+ * $ - `[ns].enabled` - `true` or `false`, to enable or disable SSL respectively
+ * $ - `[ns].keyStore` - a path to the key-store file; can be relative to the current directory
+ * $ - `[ns].keyStorePassword` - a password to the key-store file
+ * $ - `[ns].keyPassword` - a password to the private key
+ * $ - `[ns].keyStoreType` - the type of the key-store
+ * $ - `[ns].needClientAuth` - whether SSL needs client authentication
+ * $ - `[ns].trustStore` - a path to the trust-store file; can be relative to the current
+ * directory
+ * $ - `[ns].trustStorePassword` - a password to the trust-store file
+ * $ - `[ns].trustStoreType` - the type of trust-store
+ * $ - `[ns].protocol` - a protocol name supported by a particular Java version
+ * $ - `[ns].enabledAlgorithms` - a comma separated list of ciphers
+ *
+ * For a list of protocols and ciphers supported by particular Java versions, you may go to
+ * [[https://blogs.oracle.com/java-platform-group/entry/diagnosing_tls_ssl_and_https Oracle
+ * blog page]].
+ *
+ * You can optionally specify the default configuration. If you do, for each setting which is
+ * missing in SparkConf, the corresponding setting is used from the default configuration.
+ *
+ * @param conf Spark configuration object where the settings are collected from
+ * @param ns the namespace name
+ * @param defaults the default configuration
+ * @return [[org.apache.spark.SSLOptions]] object
+ */
def parse(conf: SparkConf, ns: String, defaults: Option[SSLOptions] = None): SSLOptions = {
val enabled = conf.getBoolean(s"$ns.enabled", defaultValue = defaults.exists(_.enabled))
diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala
index 5da2e98f1f..acce6bc24f 100644
--- a/core/src/main/scala/org/apache/spark/SparkConf.scala
+++ b/core/src/main/scala/org/apache/spark/SparkConf.scala
@@ -419,8 +419,10 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging {
*/
private[spark] def getenv(name: String): String = System.getenv(name)
- /** Checks for illegal or deprecated config settings. Throws an exception for the former. Not
- * idempotent - may mutate this conf object to convert deprecated settings to supported ones. */
+ /**
+ * Checks for illegal or deprecated config settings. Throws an exception for the former. Not
+ * idempotent - may mutate this conf object to convert deprecated settings to supported ones.
+ */
private[spark] def validateSettings() {
if (contains("spark.local.dir")) {
val msg = "In Spark 1.0 and later spark.local.dir will be overridden by the value set by " +
@@ -454,9 +456,9 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging {
"Set them directly on a SparkConf or in a properties file when using ./bin/spark-submit."
throw new Exception(msg)
}
- if (javaOpts.contains("-Xmx") || javaOpts.contains("-Xms")) {
- val msg = s"$executorOptsKey is not allowed to alter memory settings (was '$javaOpts'). " +
- "Use spark.executor.memory instead."
+ if (javaOpts.contains("-Xmx")) {
+ val msg = s"$executorOptsKey is not allowed to specify max heap memory settings " +
+ s"(was '$javaOpts'). Use spark.executor.memory instead."
throw new Exception(msg)
}
}
diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index dcb41f3a40..e41088f7c8 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -147,8 +147,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
appName: String,
sparkHome: String = null,
jars: Seq[String] = Nil,
- environment: Map[String, String] = Map()) =
- {
+ environment: Map[String, String] = Map()) = {
this(SparkContext.updatedConf(new SparkConf(), master, appName, sparkHome, jars, environment))
}
@@ -603,8 +602,10 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
}
/**
- * Set a local property that affects jobs submitted from this thread, such as the
- * Spark fair scheduler pool.
+ * Set a local property that affects jobs submitted from this thread, such as the Spark fair
+ * scheduler pool. User-defined properties may also be set here. These properties are propagated
+ * through to worker tasks and can be accessed there via
+ * [[org.apache.spark.TaskContext#getLocalProperty]].
*/
def setLocalProperty(key: String, value: String) {
if (value == null) {
@@ -722,7 +723,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
(safeEnd - safeStart) / step + 1
}
}
- parallelize(0 until numSlices, numSlices).mapPartitionsWithIndex((i, _) => {
+ parallelize(0 until numSlices, numSlices).mapPartitionsWithIndex { (i, _) =>
val partitionStart = (i * numElements) / numSlices * step + start
val partitionEnd = (((i + 1) * numElements) / numSlices) * step + start
def getSafeMargin(bi: BigInt): Long =
@@ -761,7 +762,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
ret
}
}
- })
+ }
}
/** Distribute a local Scala collection to form an RDD.
@@ -774,9 +775,11 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
parallelize(seq, numSlices)
}
- /** Distribute a local Scala collection to form an RDD, with one or more
- * location preferences (hostnames of Spark nodes) for each object.
- * Create a new partition for each collection item. */
+ /**
+ * Distribute a local Scala collection to form an RDD, with one or more
+ * location preferences (hostnames of Spark nodes) for each object.
+ * Create a new partition for each collection item.
+ */
def makeRDD[T: ClassTag](seq: Seq[(T, Seq[String])]): RDD[T] = withScope {
assertNotStopped()
val indexToPrefs = seq.zipWithIndex.map(t => (t._2, t._1._2)).toMap
@@ -1096,14 +1099,15 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
new NewHadoopRDD(this, fClass, kClass, vClass, jconf)
}
- /** Get an RDD for a Hadoop SequenceFile with given key and value types.
- *
- * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each
- * record, directly caching the returned RDD or directly passing it to an aggregation or shuffle
- * operation will create many references to the same object.
- * If you plan to directly cache, sort, or aggregate Hadoop writable objects, you should first
- * copy them using a `map` function.
- */
+ /**
+ * Get an RDD for a Hadoop SequenceFile with given key and value types.
+ *
+ * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each
+ * record, directly caching the returned RDD or directly passing it to an aggregation or shuffle
+ * operation will create many references to the same object.
+ * If you plan to directly cache, sort, or aggregate Hadoop writable objects, you should first
+ * copy them using a `map` function.
+ */
def sequenceFile[K, V](path: String,
keyClass: Class[K],
valueClass: Class[V],
@@ -1114,14 +1118,15 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
hadoopFile(path, inputFormatClass, keyClass, valueClass, minPartitions)
}
- /** Get an RDD for a Hadoop SequenceFile with given key and value types.
- *
- * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each
- * record, directly caching the returned RDD or directly passing it to an aggregation or shuffle
- * operation will create many references to the same object.
- * If you plan to directly cache, sort, or aggregate Hadoop writable objects, you should first
- * copy them using a `map` function.
- * */
+ /**
+ * Get an RDD for a Hadoop SequenceFile with given key and value types.
+ *
+ * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each
+ * record, directly caching the returned RDD or directly passing it to an aggregation or shuffle
+ * operation will create many references to the same object.
+ * If you plan to directly cache, sort, or aggregate Hadoop writable objects, you should first
+ * copy them using a `map` function.
+ */
def sequenceFile[K, V](
path: String,
keyClass: Class[K],
@@ -1353,10 +1358,20 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
* Register a listener to receive up-calls from events that happen during execution.
*/
@DeveloperApi
- def addSparkListener(listener: SparkListener) {
+ def addSparkListener(listener: SparkListenerInterface) {
listenerBus.addListener(listener)
}
+ private[spark] override def getExecutorIds(): Seq[String] = {
+ schedulerBackend match {
+ case b: CoarseGrainedSchedulerBackend =>
+ b.getExecutorIds()
+ case _ =>
+ logWarning("Requesting executors is only supported in coarse-grained mode")
+ Nil
+ }
+ }
+
/**
* Update the cluster manager on our scheduling needs. Three bits of information are included
* to help it make decisions.
@@ -1994,7 +2009,9 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
// Use reflection to find the right constructor
val constructors = {
val listenerClass = Utils.classForName(className)
- listenerClass.getConstructors.asInstanceOf[Array[Constructor[_ <: SparkListener]]]
+ listenerClass
+ .getConstructors
+ .asInstanceOf[Array[Constructor[_ <: SparkListenerInterface]]]
}
val constructorTakingSparkConf = constructors.find { c =>
c.getParameterTypes.sameElements(Array(classOf[SparkConf]))
@@ -2002,7 +2019,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
lazy val zeroArgumentConstructor = constructors.find { c =>
c.getParameterTypes.isEmpty
}
- val listener: SparkListener = {
+ val listener: SparkListenerInterface = {
if (constructorTakingSparkConf.isDefined) {
constructorTakingSparkConf.get.newInstance(conf)
} else if (zeroArgumentConstructor.isDefined) {
@@ -2380,9 +2397,8 @@ object SparkContext extends Logging {
} catch {
// TODO: Enumerate the exact reasons why it can fail
// But irrespective of it, it means we cannot proceed !
- case e: Exception => {
+ case e: Exception =>
throw new SparkException("YARN mode not available ?", e)
- }
}
val backend = try {
val clazz =
@@ -2390,9 +2406,8 @@ object SparkContext extends Logging {
val cons = clazz.getConstructor(classOf[TaskSchedulerImpl], classOf[SparkContext])
cons.newInstance(scheduler, sc).asInstanceOf[CoarseGrainedSchedulerBackend]
} catch {
- case e: Exception => {
+ case e: Exception =>
throw new SparkException("YARN mode not available ?", e)
- }
}
scheduler.initialize(backend)
(backend, scheduler)
@@ -2404,9 +2419,8 @@ object SparkContext extends Logging {
cons.newInstance(sc).asInstanceOf[TaskSchedulerImpl]
} catch {
- case e: Exception => {
+ case e: Exception =>
throw new SparkException("YARN mode not available ?", e)
- }
}
val backend = try {
@@ -2415,9 +2429,8 @@ object SparkContext extends Logging {
val cons = clazz.getConstructor(classOf[TaskSchedulerImpl], classOf[SparkContext])
cons.newInstance(scheduler, sc).asInstanceOf[CoarseGrainedSchedulerBackend]
} catch {
- case e: Exception => {
+ case e: Exception =>
throw new SparkException("YARN mode not available ?", e)
- }
}
scheduler.initialize(backend)
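
A minimal sketch (not part of this diff) of the widened addSparkListener(SparkListenerInterface)
entry point; SparkListener implements the interface, so a plain subclass still registers cleanly
(sc is assumed to be an existing SparkContext):

  import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd}

  class TaskCounter extends SparkListener {
    @volatile var finished = 0L
    override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = { finished += 1 }
  }

  sc.addSparkListener(new TaskCounter)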
diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala
index 700e2cb3f9..3d11db7461 100644
--- a/core/src/main/scala/org/apache/spark/SparkEnv.scala
+++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -101,14 +101,13 @@ class SparkEnv (
// We only need to delete the tmp dir created by the driver, because sparkFilesDir points to the
// current working dir in the executor, which we do not need to delete.
driverTmpDirToDelete match {
- case Some(path) => {
+ case Some(path) =>
try {
Utils.deleteRecursively(new File(path))
} catch {
case e: Exception =>
logWarning(s"Exception while deleting Spark temp dir: $path", e)
}
- }
case None => // We just need to delete tmp dir created by driver, so do nothing on executor
}
}
@@ -314,7 +313,8 @@ object SparkEnv extends Logging {
UnifiedMemoryManager(conf, numUsableCores)
}
- val blockTransferService = new NettyBlockTransferService(conf, securityManager, numUsableCores)
+ val blockTransferService =
+ new NettyBlockTransferService(conf, securityManager, hostname, numUsableCores)
val blockManagerMaster = new BlockManagerMaster(registerOrLookupEndpoint(
BlockManagerMaster.DRIVER_ENDPOINT_NAME,
diff --git a/core/src/main/scala/org/apache/spark/SparkStatusTracker.scala b/core/src/main/scala/org/apache/spark/SparkStatusTracker.scala
index 34ee3a48f8..52c4656c27 100644
--- a/core/src/main/scala/org/apache/spark/SparkStatusTracker.scala
+++ b/core/src/main/scala/org/apache/spark/SparkStatusTracker.scala
@@ -17,6 +17,8 @@
package org.apache.spark
+import org.apache.spark.scheduler.TaskSchedulerImpl
+
/**
* Low-level status reporting APIs for monitoring job and stage progress.
*
@@ -104,4 +106,22 @@ class SparkStatusTracker private[spark] (sc: SparkContext) {
}
}
}
+
+ /**
+ * Returns information about all known executors, including host, port, cacheSize and numRunningTasks.
+ */
+ def getExecutorInfos: Array[SparkExecutorInfo] = {
+ val executorIdToRunningTasks: Map[String, Int] =
+ sc.taskScheduler.asInstanceOf[TaskSchedulerImpl].runningTasksByExecutors()
+
+ sc.getExecutorStorageStatus.map { status =>
+ val bmId = status.blockManagerId
+ new SparkExecutorInfoImpl(
+ bmId.host,
+ bmId.port,
+ status.cacheSize,
+ executorIdToRunningTasks.getOrElse(bmId.executorId, 0)
+ )
+ }
+ }
}
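
A short usage sketch (not part of this diff) for the new getExecutorInfos, assuming the
SparkExecutorInfo accessors host(), port(), cacheSize() and numRunningTasks() exposed by the
accompanying Java interface:

  val tracker = sc.statusTracker
  tracker.getExecutorInfos.foreach { e =>
    println(s"${e.host()}:${e.port()} cached=${e.cacheSize()} bytes, running=${e.numRunningTasks()}")
  }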
diff --git a/core/src/main/scala/org/apache/spark/StatusAPIImpl.scala b/core/src/main/scala/org/apache/spark/StatusAPIImpl.scala
index e5c7c8d0db..c1f24a6377 100644
--- a/core/src/main/scala/org/apache/spark/StatusAPIImpl.scala
+++ b/core/src/main/scala/org/apache/spark/StatusAPIImpl.scala
@@ -18,18 +18,25 @@
package org.apache.spark
private class SparkJobInfoImpl (
- val jobId: Int,
- val stageIds: Array[Int],
- val status: JobExecutionStatus)
- extends SparkJobInfo
+ val jobId: Int,
+ val stageIds: Array[Int],
+ val status: JobExecutionStatus)
+ extends SparkJobInfo
private class SparkStageInfoImpl(
- val stageId: Int,
- val currentAttemptId: Int,
- val submissionTime: Long,
- val name: String,
- val numTasks: Int,
- val numActiveTasks: Int,
- val numCompletedTasks: Int,
- val numFailedTasks: Int)
- extends SparkStageInfo
+ val stageId: Int,
+ val currentAttemptId: Int,
+ val submissionTime: Long,
+ val name: String,
+ val numTasks: Int,
+ val numActiveTasks: Int,
+ val numCompletedTasks: Int,
+ val numFailedTasks: Int)
+ extends SparkStageInfo
+
+private class SparkExecutorInfoImpl(
+ val host: String,
+ val port: Int,
+ val cacheSize: Long,
+ val numRunningTasks: Int)
+ extends SparkExecutorInfo
diff --git a/core/src/main/scala/org/apache/spark/TaskContext.scala b/core/src/main/scala/org/apache/spark/TaskContext.scala
index 0c1a1dec30..5b2fca4b2d 100644
--- a/core/src/main/scala/org/apache/spark/TaskContext.scala
+++ b/core/src/main/scala/org/apache/spark/TaskContext.scala
@@ -18,6 +18,7 @@
package org.apache.spark
import java.io.Serializable
+import java.util.Properties
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.executor.TaskMetrics
@@ -64,7 +65,7 @@ object TaskContext {
* An empty task context that does not represent an actual task.
*/
private[spark] def empty(): TaskContextImpl = {
- new TaskContextImpl(0, 0, 0, 0, null, null)
+ new TaskContextImpl(0, 0, 0, 0, null, new Properties, null)
}
}
@@ -170,6 +171,12 @@ abstract class TaskContext extends Serializable {
*/
def taskAttemptId(): Long
+ /**
+ * Get a local property set upstream in the driver, or null if it is missing. See also
+ * [[org.apache.spark.SparkContext.setLocalProperty]].
+ */
+ def getLocalProperty(key: String): String
+
@DeveloperApi
def taskMetrics(): TaskMetrics
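
A minimal round-trip sketch (not part of this diff) for the new getLocalProperty, using a
hypothetical property key and an existing SparkContext sc:

  import org.apache.spark.TaskContext

  sc.setLocalProperty("my.job.tag", "nightly")   // hypothetical key, set on the driver thread
  sc.parallelize(1 to 4, 2).foreach { _ =>
    val tag = TaskContext.get().getLocalProperty("my.job.tag")
    require(tag == "nightly")                    // the property is visible inside the task
  }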
diff --git a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala
index 87dc7f30e7..8b407f9771 100644
--- a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala
+++ b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala
@@ -17,6 +17,8 @@
package org.apache.spark
+import java.util.Properties
+
import scala.collection.mutable.ArrayBuffer
import org.apache.spark.executor.TaskMetrics
@@ -32,6 +34,7 @@ private[spark] class TaskContextImpl(
override val taskAttemptId: Long,
override val attemptNumber: Int,
override val taskMemoryManager: TaskMemoryManager,
+ localProperties: Properties,
@transient private val metricsSystem: MetricsSystem,
initialAccumulators: Seq[Accumulator[_]] = InternalAccumulator.createAll())
extends TaskContext
@@ -119,6 +122,8 @@ private[spark] class TaskContextImpl(
override def isInterrupted(): Boolean = interrupted
+ override def getLocalProperty(key: String): String = localProperties.getProperty(key)
+
override def getMetricsSources(sourceName: String): Seq[Source] =
metricsSystem.getSourcesByName(sourceName)
diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala
index e080f91f50..2897272a8b 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala
@@ -461,10 +461,10 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])
fromRDD(rdd.partitionBy(partitioner))
/**
- * Return an RDD containing all pairs of elements with matching keys in `this` and `other`. Each
- * pair of elements will be returned as a (k, (v1, v2)) tuple, where (k, v1) is in `this` and
- * (k, v2) is in `other`. Uses the given Partitioner to partition the output RDD.
- */
+ * Return an RDD containing all pairs of elements with matching keys in `this` and `other`. Each
+ * pair of elements will be returned as a (k, (v1, v2)) tuple, where (k, v1) is in `this` and
+ * (k, v2) is in `other`. Uses the given Partitioner to partition the output RDD.
+ */
def join[W](other: JavaPairRDD[K, W], partitioner: Partitioner): JavaPairRDD[K, (V, W)] =
fromRDD(rdd.join(other, partitioner))
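
The join semantics restated above can be checked with a short sketch (not part of this diff),
written against the Scala pair-RDD API for brevity:

  import org.apache.spark.HashPartitioner

  val left  = sc.parallelize(Seq(("a", 1), ("b", 2)))
  val right = sc.parallelize(Seq(("a", "x"), ("c", "y")))
  // Only keys present on both sides survive: Array(("a", (1, "x")))
  val joined = left.join(right, new HashPartitioner(4)).collect()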
diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala
index d362c40b7a..dfd91ae338 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala
@@ -295,13 +295,14 @@ class JavaSparkContext(val sc: SparkContext)
new JavaRDD(sc.binaryRecords(path, recordLength))
}
- /** Get an RDD for a Hadoop SequenceFile with given key and value types.
- *
- * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each
- * record, directly caching the returned RDD will create many references to the same object.
- * If you plan to directly cache Hadoop writable objects, you should first copy them using
- * a `map` function.
- * */
+ /**
+ * Get an RDD for a Hadoop SequenceFile with given key and value types.
+ *
+ * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each
+ * record, directly caching the returned RDD will create many references to the same object.
+ * If you plan to directly cache Hadoop writable objects, you should first copy them using
+ * a `map` function.
+ */
def sequenceFile[K, V](path: String,
keyClass: Class[K],
valueClass: Class[V],
@@ -312,13 +313,14 @@ class JavaSparkContext(val sc: SparkContext)
new JavaPairRDD(sc.sequenceFile(path, keyClass, valueClass, minPartitions))
}
- /** Get an RDD for a Hadoop SequenceFile.
- *
- * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each
- * record, directly caching the returned RDD will create many references to the same object.
- * If you plan to directly cache Hadoop writable objects, you should first copy them using
- * a `map` function.
- */
+ /**
+ * Get an RDD for a Hadoop SequenceFile.
+ *
+ * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each
+ * record, directly caching the returned RDD will create many references to the same object.
+ * If you plan to directly cache Hadoop writable objects, you should first copy them using
+ * a `map` function.
+ */
def sequenceFile[K, V](path: String, keyClass: Class[K], valueClass: Class[V]):
JavaPairRDD[K, V] = {
implicit val ctagK: ClassTag[K] = ClassTag(keyClass)
@@ -411,13 +413,14 @@ class JavaSparkContext(val sc: SparkContext)
new JavaHadoopRDD(rdd.asInstanceOf[HadoopRDD[K, V]])
}
- /** Get an RDD for a Hadoop file with an arbitrary InputFormat.
- *
- * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each
- * record, directly caching the returned RDD will create many references to the same object.
- * If you plan to directly cache Hadoop writable objects, you should first copy them using
- * a `map` function.
- */
+ /**
+ * Get an RDD for a Hadoop file with an arbitrary InputFormat.
+ *
+ * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each
+ * record, directly caching the returned RDD will create many references to the same object.
+ * If you plan to directly cache Hadoop writable objects, you should first copy them using
+ * a `map` function.
+ */
def hadoopFile[K, V, F <: InputFormat[K, V]](
path: String,
inputFormatClass: Class[F],
@@ -431,13 +434,14 @@ class JavaSparkContext(val sc: SparkContext)
new JavaHadoopRDD(rdd.asInstanceOf[HadoopRDD[K, V]])
}
- /** Get an RDD for a Hadoop file with an arbitrary InputFormat
- *
- * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each
- * record, directly caching the returned RDD will create many references to the same object.
- * If you plan to directly cache Hadoop writable objects, you should first copy them using
- * a `map` function.
- */
+ /**
+ * Get an RDD for a Hadoop file with an arbitrary InputFormat
+ *
+ * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each
+ * record, directly caching the returned RDD will create many references to the same object.
+ * If you plan to directly cache Hadoop writable objects, you should first copy them using
+ * a `map` function.
+ */
def hadoopFile[K, V, F <: InputFormat[K, V]](
path: String,
inputFormatClass: Class[F],
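
The reuse warning above is why Hadoop records are usually copied before caching; a minimal
sketch (not part of this diff), assuming a SequenceFile of IntWritable/Text pairs at a
hypothetical path:

  import org.apache.hadoop.io.{IntWritable, Text}

  val raw = sc.sequenceFile("/data/events.seq", classOf[IntWritable], classOf[Text])
  // Copy out of the reused Writable instances so every cached record is a distinct object.
  val safe = raw.map { case (k, v) => (k.get, v.toString) }.cache()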
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala b/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala
index 6f6730690f..6259bead3e 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala
@@ -134,11 +134,10 @@ private[python] class JavaToWritableConverter extends Converter[Any, Writable] {
mapWritable.put(convertToWritable(k), convertToWritable(v))
}
mapWritable
- case array: Array[Any] => {
+ case array: Array[Any] =>
val arrayWriteable = new ArrayWritable(classOf[Writable])
arrayWriteable.set(array.map(convertToWritable(_)))
arrayWriteable
- }
case other => throw new SparkException(
s"Data of type ${other.getClass.getName} cannot be used")
}
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index f423b2ee56..ab5b6c8380 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -59,7 +59,7 @@ private[spark] class PythonRDD(
val asJavaRDD: JavaRDD[Array[Byte]] = JavaRDD.fromRDD(this)
override def compute(split: Partition, context: TaskContext): Iterator[Array[Byte]] = {
- val runner = new PythonRunner(func, bufferSize, reuse_worker)
+ val runner = PythonRunner(func, bufferSize, reuse_worker)
runner.compute(firstParent.iterator(split, context), split.index, context)
}
}
@@ -78,17 +78,41 @@ private[spark] case class PythonFunction(
accumulator: Accumulator[JList[Array[Byte]]])
/**
- * A helper class to run Python UDFs in Spark.
+ * A wrapper for chained Python functions (from bottom to top).
+ * @param funcs the chained Python functions, applied from bottom to top
+ */
+private[spark] case class ChainedPythonFunctions(funcs: Seq[PythonFunction])
+
+private[spark] object PythonRunner {
+ def apply(func: PythonFunction, bufferSize: Int, reuse_worker: Boolean): PythonRunner = {
+ new PythonRunner(
+ Seq(ChainedPythonFunctions(Seq(func))), bufferSize, reuse_worker, false, Array(Array(0)))
+ }
+}
+
+/**
+ * A helper class to run Python mapPartition/UDFs in Spark.
+ *
+ * funcs is a list of independent Python functions; each of them is a chain of Python
+ * functions (applied from bottom to top).
*/
private[spark] class PythonRunner(
- func: PythonFunction,
+ funcs: Seq[ChainedPythonFunctions],
bufferSize: Int,
- reuse_worker: Boolean)
+ reuse_worker: Boolean,
+ isUDF: Boolean,
+ argOffsets: Array[Array[Int]])
extends Logging {
- private val envVars = func.envVars
- private val pythonExec = func.pythonExec
- private val accumulator = func.accumulator
+ require(funcs.length == argOffsets.length, "argOffsets should have the same length as funcs")
+
+ // All the Python functions should have the same exec, version and envvars.
+ private val envVars = funcs.head.funcs.head.envVars
+ private val pythonExec = funcs.head.funcs.head.pythonExec
+ private val pythonVer = funcs.head.funcs.head.pythonVer
+
+ // TODO: support accumulator in multiple UDF
+ private val accumulator = funcs.head.funcs.head.accumulator
def compute(
inputIterator: Iterator[_],
@@ -228,10 +252,8 @@ private[spark] class PythonRunner(
@volatile private var _exception: Exception = null
- private val pythonVer = func.pythonVer
- private val pythonIncludes = func.pythonIncludes
- private val broadcastVars = func.broadcastVars
- private val command = func.command
+ private val pythonIncludes = funcs.flatMap(_.funcs.flatMap(_.pythonIncludes.asScala)).toSet
+ private val broadcastVars = funcs.flatMap(_.funcs.flatMap(_.broadcastVars.asScala))
setDaemon(true)
@@ -256,13 +278,13 @@ private[spark] class PythonRunner(
// sparkFilesDir
PythonRDD.writeUTF(SparkFiles.getRootDirectory(), dataOut)
// Python includes (*.zip and *.egg files)
- dataOut.writeInt(pythonIncludes.size())
- for (include <- pythonIncludes.asScala) {
+ dataOut.writeInt(pythonIncludes.size)
+ for (include <- pythonIncludes) {
PythonRDD.writeUTF(include, dataOut)
}
// Broadcast variables
val oldBids = PythonRDD.getWorkerBroadcasts(worker)
- val newBids = broadcastVars.asScala.map(_.id).toSet
+ val newBids = broadcastVars.map(_.id).toSet
// number of different broadcasts
val toRemove = oldBids.diff(newBids)
val cnt = toRemove.size + newBids.diff(oldBids).size
@@ -272,7 +294,7 @@ private[spark] class PythonRunner(
dataOut.writeLong(- bid - 1) // bid >= 0
oldBids.remove(bid)
}
- for (broadcast <- broadcastVars.asScala) {
+ for (broadcast <- broadcastVars) {
if (!oldBids.contains(broadcast.id)) {
// send new broadcast
dataOut.writeLong(broadcast.id)
@@ -282,8 +304,26 @@ private[spark] class PythonRunner(
}
dataOut.flush()
// Serialized command:
- dataOut.writeInt(command.length)
- dataOut.write(command)
+ if (isUDF) {
+ dataOut.writeInt(1)
+ dataOut.writeInt(funcs.length)
+ funcs.zip(argOffsets).foreach { case (chained, offsets) =>
+ dataOut.writeInt(offsets.length)
+ offsets.foreach { offset =>
+ dataOut.writeInt(offset)
+ }
+ dataOut.writeInt(chained.funcs.length)
+ chained.funcs.foreach { f =>
+ dataOut.writeInt(f.command.length)
+ dataOut.write(f.command)
+ }
+ }
+ } else {
+ dataOut.writeInt(0)
+ val command = funcs.head.funcs.head.command
+ dataOut.writeInt(command.length)
+ dataOut.write(command)
+ }
// Data values
PythonRDD.writeIteratorToStream(inputIterator, dataOut)
dataOut.writeInt(SpecialLengths.END_OF_DATA_SECTION)
@@ -413,6 +453,10 @@ private[spark] object PythonRDD extends Logging {
serveIterator(rdd.collect().iterator, s"serve RDD ${rdd.id}")
}
+ def toLocalIteratorAndServe[T](rdd: RDD[T]): Int = {
+ serveIterator(rdd.toLocalIterator, s"serve toLocalIterator")
+ }
+
def readRDDFromFile(sc: JavaSparkContext, filename: String, parallelism: Int):
JavaRDD[Array[Byte]] = {
val file = new DataInputStream(new FileInputStream(filename))
@@ -426,7 +470,7 @@ private[spark] object PythonRDD extends Logging {
objs.append(obj)
}
} catch {
- case eof: EOFException => {}
+ case eof: EOFException => // No-op
}
JavaRDD.fromRDD(sc.sc.parallelize(objs, parallelism))
} finally {
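
A hedged sketch (not part of this diff) of how the reworked PythonRunner groups its inputs;
funcA, funcF and funcG stand in for existing PythonFunction instances, and argOffsets picks the
input columns each independent function consumes:

  val udf1 = ChainedPythonFunctions(Seq(funcA))           // a single, unchained function
  val udf2 = ChainedPythonFunctions(Seq(funcF, funcG))    // g(f(...)), applied bottom-to-top
  val runner = new PythonRunner(
    Seq(udf1, udf2), bufferSize = 65536, reuse_worker = true,
    isUDF = true, argOffsets = Array(Array(0), Array(1)))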
diff --git a/core/src/main/scala/org/apache/spark/api/r/RRDD.scala b/core/src/main/scala/org/apache/spark/api/r/RRDD.scala
index 588a57e65f..606ba6ef86 100644
--- a/core/src/main/scala/org/apache/spark/api/r/RRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/r/RRDD.scala
@@ -17,21 +17,16 @@
package org.apache.spark.api.r
-import java.io._
-import java.net.{InetAddress, ServerSocket}
-import java.util.{Arrays, Map => JMap}
+import java.util.{Map => JMap}
import scala.collection.JavaConverters._
-import scala.io.Source
import scala.reflect.ClassTag
-import scala.util.Try
import org.apache.spark._
import org.apache.spark.api.java.{JavaPairRDD, JavaRDD, JavaSparkContext}
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
-import org.apache.spark.util.Utils
private abstract class BaseRRDD[T: ClassTag, U: ClassTag](
parent: RDD[T],
@@ -42,188 +37,16 @@ private abstract class BaseRRDD[T: ClassTag, U: ClassTag](
packageNames: Array[Byte],
broadcastVars: Array[Broadcast[Object]])
extends RDD[U](parent) with Logging {
- protected var dataStream: DataInputStream = _
- private var bootTime: Double = _
override def getPartitions: Array[Partition] = parent.partitions
override def compute(partition: Partition, context: TaskContext): Iterator[U] = {
-
- // Timing start
- bootTime = System.currentTimeMillis / 1000.0
+ val runner = new RRunner[U](
+ func, deserializer, serializer, packageNames, broadcastVars, numPartitions)
// The parent may be also an RRDD, so we should launch it first.
val parentIterator = firstParent[T].iterator(partition, context)
- // we expect two connections
- val serverSocket = new ServerSocket(0, 2, InetAddress.getByName("localhost"))
- val listenPort = serverSocket.getLocalPort()
-
- // The stdout/stderr is shared by multiple tasks, because we use one daemon
- // to launch child process as worker.
- val errThread = RRDD.createRWorker(listenPort)
-
- // We use two sockets to separate input and output, then it's easy to manage
- // the lifecycle of them to avoid deadlock.
- // TODO: optimize it to use one socket
-
- // the socket used to send out the input of task
- serverSocket.setSoTimeout(10000)
- val inSocket = serverSocket.accept()
- startStdinThread(inSocket.getOutputStream(), parentIterator, partition.index)
-
- // the socket used to receive the output of task
- val outSocket = serverSocket.accept()
- val inputStream = new BufferedInputStream(outSocket.getInputStream)
- dataStream = new DataInputStream(inputStream)
- serverSocket.close()
-
- try {
-
- return new Iterator[U] {
- def next(): U = {
- val obj = _nextObj
- if (hasNext) {
- _nextObj = read()
- }
- obj
- }
-
- var _nextObj = read()
-
- def hasNext(): Boolean = {
- val hasMore = (_nextObj != null)
- if (!hasMore) {
- dataStream.close()
- }
- hasMore
- }
- }
- } catch {
- case e: Exception =>
- throw new SparkException("R computation failed with\n " + errThread.getLines())
- }
- }
-
- /**
- * Start a thread to write RDD data to the R process.
- */
- private def startStdinThread[T](
- output: OutputStream,
- iter: Iterator[T],
- partition: Int): Unit = {
-
- val env = SparkEnv.get
- val taskContext = TaskContext.get()
- val bufferSize = System.getProperty("spark.buffer.size", "65536").toInt
- val stream = new BufferedOutputStream(output, bufferSize)
-
- new Thread("writer for R") {
- override def run(): Unit = {
- try {
- SparkEnv.set(env)
- TaskContext.setTaskContext(taskContext)
- val dataOut = new DataOutputStream(stream)
- dataOut.writeInt(partition)
-
- SerDe.writeString(dataOut, deserializer)
- SerDe.writeString(dataOut, serializer)
-
- dataOut.writeInt(packageNames.length)
- dataOut.write(packageNames)
-
- dataOut.writeInt(func.length)
- dataOut.write(func)
-
- dataOut.writeInt(broadcastVars.length)
- broadcastVars.foreach { broadcast =>
- // TODO(shivaram): Read a Long in R to avoid this cast
- dataOut.writeInt(broadcast.id.toInt)
- // TODO: Pass a byte array from R to avoid this cast ?
- val broadcastByteArr = broadcast.value.asInstanceOf[Array[Byte]]
- dataOut.writeInt(broadcastByteArr.length)
- dataOut.write(broadcastByteArr)
- }
-
- dataOut.writeInt(numPartitions)
-
- if (!iter.hasNext) {
- dataOut.writeInt(0)
- } else {
- dataOut.writeInt(1)
- }
-
- val printOut = new PrintStream(stream)
-
- def writeElem(elem: Any): Unit = {
- if (deserializer == SerializationFormats.BYTE) {
- val elemArr = elem.asInstanceOf[Array[Byte]]
- dataOut.writeInt(elemArr.length)
- dataOut.write(elemArr)
- } else if (deserializer == SerializationFormats.ROW) {
- dataOut.write(elem.asInstanceOf[Array[Byte]])
- } else if (deserializer == SerializationFormats.STRING) {
- // write string(for StringRRDD)
- // scalastyle:off println
- printOut.println(elem)
- // scalastyle:on println
- }
- }
-
- for (elem <- iter) {
- elem match {
- case (key, value) =>
- writeElem(key)
- writeElem(value)
- case _ =>
- writeElem(elem)
- }
- }
- stream.flush()
- } catch {
- // TODO: We should propogate this error to the task thread
- case e: Exception =>
- logError("R Writer thread got an exception", e)
- } finally {
- Try(output.close())
- }
- }
- }.start()
- }
-
- protected def readData(length: Int): U
-
- protected def read(): U = {
- try {
- val length = dataStream.readInt()
-
- length match {
- case SpecialLengths.TIMING_DATA =>
- // Timing data from R worker
- val boot = dataStream.readDouble - bootTime
- val init = dataStream.readDouble
- val broadcast = dataStream.readDouble
- val input = dataStream.readDouble
- val compute = dataStream.readDouble
- val output = dataStream.readDouble
- logInfo(
- ("Times: boot = %.3f s, init = %.3f s, broadcast = %.3f s, " +
- "read-input = %.3f s, compute = %.3f s, write-output = %.3f s, " +
- "total = %.3f s").format(
- boot,
- init,
- broadcast,
- input,
- compute,
- output,
- boot + init + broadcast + input + compute + output))
- read()
- case length if length >= 0 =>
- readData(length)
- }
- } catch {
- case eof: EOFException =>
- throw new SparkException("R worker exited unexpectedly (cranshed)", eof)
- }
+ runner.compute(parentIterator, partition.index, context)
}
}
@@ -242,19 +65,6 @@ private class PairwiseRRDD[T: ClassTag](
parent, numPartitions, hashFunc, deserializer,
SerializationFormats.BYTE, packageNames,
broadcastVars.map(x => x.asInstanceOf[Broadcast[Object]])) {
-
- override protected def readData(length: Int): (Int, Array[Byte]) = {
- length match {
- case length if length == 2 =>
- val hashedKey = dataStream.readInt()
- val contentPairsLength = dataStream.readInt()
- val contentPairs = new Array[Byte](contentPairsLength)
- dataStream.readFully(contentPairs)
- (hashedKey, contentPairs)
- case _ => null
- }
- }
-
lazy val asJavaPairRDD : JavaPairRDD[Int, Array[Byte]] = JavaPairRDD.fromRDD(this)
}
@@ -271,17 +81,6 @@ private class RRDD[T: ClassTag](
extends BaseRRDD[T, Array[Byte]](
parent, -1, func, deserializer, serializer, packageNames,
broadcastVars.map(x => x.asInstanceOf[Broadcast[Object]])) {
-
- override protected def readData(length: Int): Array[Byte] = {
- length match {
- case length if length > 0 =>
- val obj = new Array[Byte](length)
- dataStream.readFully(obj)
- obj
- case _ => null
- }
- }
-
lazy val asJavaRDD : JavaRDD[Array[Byte]] = JavaRDD.fromRDD(this)
}
@@ -297,55 +96,10 @@ private class StringRRDD[T: ClassTag](
extends BaseRRDD[T, String](
parent, -1, func, deserializer, SerializationFormats.STRING, packageNames,
broadcastVars.map(x => x.asInstanceOf[Broadcast[Object]])) {
-
- override protected def readData(length: Int): String = {
- length match {
- case length if length > 0 =>
- SerDe.readStringBytes(dataStream, length)
- case _ => null
- }
- }
-
lazy val asJavaRDD : JavaRDD[String] = JavaRDD.fromRDD(this)
}
-private object SpecialLengths {
- val TIMING_DATA = -1
-}
-
-private[r] class BufferedStreamThread(
- in: InputStream,
- name: String,
- errBufferSize: Int) extends Thread(name) with Logging {
- val lines = new Array[String](errBufferSize)
- var lineIdx = 0
- override def run() {
- for (line <- Source.fromInputStream(in).getLines) {
- synchronized {
- lines(lineIdx) = line
- lineIdx = (lineIdx + 1) % errBufferSize
- }
- logInfo(line)
- }
- }
-
- def getLines(): String = synchronized {
- (0 until errBufferSize).filter { x =>
- lines((x + lineIdx) % errBufferSize) != null
- }.map { x =>
- lines((x + lineIdx) % errBufferSize)
- }.mkString("\n")
- }
-}
-
private[r] object RRDD {
- // Because forking processes from Java is expensive, we prefer to launch
- // a single R daemon (daemon.R) and tell it to fork new workers for our tasks.
- // This daemon currently only works on UNIX-based systems now, so we should
- // also fall back to launching workers (worker.R) directly.
- private[this] var errThread: BufferedStreamThread = _
- private[this] var daemonChannel: DataOutputStream = _
-
def createSparkContext(
master: String,
appName: String,
@@ -353,7 +107,6 @@ private[r] object RRDD {
jars: Array[String],
sparkEnvirMap: JMap[Object, Object],
sparkExecutorEnvMap: JMap[Object, Object]): JavaSparkContext = {
-
val sparkConf = new SparkConf().setAppName(appName)
.setSparkHome(sparkHome)
@@ -381,83 +134,10 @@ private[r] object RRDD {
}
/**
- * Start a thread to print the process's stderr to ours
- */
- private def startStdoutThread(proc: Process): BufferedStreamThread = {
- val BUFFER_SIZE = 100
- val thread = new BufferedStreamThread(proc.getInputStream, "stdout reader for R", BUFFER_SIZE)
- thread.setDaemon(true)
- thread.start()
- thread
- }
-
- private def createRProcess(port: Int, script: String): BufferedStreamThread = {
- // "spark.sparkr.r.command" is deprecated and replaced by "spark.r.command",
- // but kept here for backward compatibility.
- val sparkConf = SparkEnv.get.conf
- var rCommand = sparkConf.get("spark.sparkr.r.command", "Rscript")
- rCommand = sparkConf.get("spark.r.command", rCommand)
-
- val rOptions = "--vanilla"
- val rLibDir = RUtils.sparkRPackagePath(isDriver = false)
- val rExecScript = rLibDir(0) + "/SparkR/worker/" + script
- val pb = new ProcessBuilder(Arrays.asList(rCommand, rOptions, rExecScript))
- // Unset the R_TESTS environment variable for workers.
- // This is set by R CMD check as startup.Rs
- // (http://svn.r-project.org/R/trunk/src/library/tools/R/testing.R)
- // and confuses worker script which tries to load a non-existent file
- pb.environment().put("R_TESTS", "")
- pb.environment().put("SPARKR_RLIBDIR", rLibDir.mkString(","))
- pb.environment().put("SPARKR_WORKER_PORT", port.toString)
- pb.redirectErrorStream(true) // redirect stderr into stdout
- val proc = pb.start()
- val errThread = startStdoutThread(proc)
- errThread
- }
-
- /**
- * ProcessBuilder used to launch worker R processes.
- */
- def createRWorker(port: Int): BufferedStreamThread = {
- val useDaemon = SparkEnv.get.conf.getBoolean("spark.sparkr.use.daemon", true)
- if (!Utils.isWindows && useDaemon) {
- synchronized {
- if (daemonChannel == null) {
- // we expect one connections
- val serverSocket = new ServerSocket(0, 1, InetAddress.getByName("localhost"))
- val daemonPort = serverSocket.getLocalPort
- errThread = createRProcess(daemonPort, "daemon.R")
- // the socket used to send out the input of task
- serverSocket.setSoTimeout(10000)
- val sock = serverSocket.accept()
- daemonChannel = new DataOutputStream(new BufferedOutputStream(sock.getOutputStream))
- serverSocket.close()
- }
- try {
- daemonChannel.writeInt(port)
- daemonChannel.flush()
- } catch {
- case e: IOException =>
- // daemon process died
- daemonChannel.close()
- daemonChannel = null
- errThread = null
- // fail the current task, retry by scheduler
- throw e
- }
- errThread
- }
- } else {
- createRProcess(port, "worker.R")
- }
- }
-
- /**
* Create an RRDD given a sequence of byte arrays. Used to create RRDD when `parallelize` is
* called from R.
*/
def createRDDFromArray(jsc: JavaSparkContext, arr: Array[Array[Byte]]): JavaRDD[Array[Byte]] = {
JavaRDD.fromRDD(jsc.sc.parallelize(arr, arr.length))
}
-
}
diff --git a/core/src/main/scala/org/apache/spark/api/r/RRunner.scala b/core/src/main/scala/org/apache/spark/api/r/RRunner.scala
new file mode 100644
index 0000000000..07d1fa2c4a
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/api/r/RRunner.scala
@@ -0,0 +1,368 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.api.r
+
+import java.io._
+import java.net.{InetAddress, ServerSocket}
+import java.util.Arrays
+
+import scala.io.Source
+import scala.util.Try
+
+import org.apache.spark._
+import org.apache.spark.broadcast.Broadcast
+import org.apache.spark.internal.Logging
+import org.apache.spark.util.Utils
+
+/**
+ * A helper class to run R UDFs in Spark.
+ */
+private[spark] class RRunner[U](
+ func: Array[Byte],
+ deserializer: String,
+ serializer: String,
+ packageNames: Array[Byte],
+ broadcastVars: Array[Broadcast[Object]],
+ numPartitions: Int = -1)
+ extends Logging {
+ private var bootTime: Double = _
+ private var dataStream: DataInputStream = _
+ val readData = numPartitions match {
+ case -1 =>
+ serializer match {
+ case SerializationFormats.STRING => readStringData _
+ case _ => readByteArrayData _
+ }
+ case _ => readShuffledData _
+ }
+
+ def compute(
+ inputIterator: Iterator[_],
+ partitionIndex: Int,
+ context: TaskContext): Iterator[U] = {
+ // Timing start
+ bootTime = System.currentTimeMillis / 1000.0
+
+ // we expect two connections
+ val serverSocket = new ServerSocket(0, 2, InetAddress.getByName("localhost"))
+ val listenPort = serverSocket.getLocalPort()
+
+ // The stdout/stderr is shared by multiple tasks, because we use one daemon
+ // to launch child process as worker.
+ val errThread = RRunner.createRWorker(listenPort)
+
+ // We use two sockets to separate input and output, then it's easy to manage
+ // the lifecycle of them to avoid deadlock.
+ // TODO: optimize it to use one socket
+
+ // the socket used to send out the input of task
+ serverSocket.setSoTimeout(10000)
+ val inSocket = serverSocket.accept()
+ startStdinThread(inSocket.getOutputStream(), inputIterator, partitionIndex)
+
+ // the socket used to receive the output of task
+ val outSocket = serverSocket.accept()
+ val inputStream = new BufferedInputStream(outSocket.getInputStream)
+ dataStream = new DataInputStream(inputStream)
+ serverSocket.close()
+
+ try {
+ return new Iterator[U] {
+ def next(): U = {
+ val obj = _nextObj
+ if (hasNext) {
+ _nextObj = read()
+ }
+ obj
+ }
+
+ var _nextObj = read()
+
+ def hasNext(): Boolean = {
+ val hasMore = (_nextObj != null)
+ if (!hasMore) {
+ dataStream.close()
+ }
+ hasMore
+ }
+ }
+ } catch {
+ case e: Exception =>
+ throw new SparkException("R computation failed with\n " + errThread.getLines())
+ }
+ }
+
+ /**
+ * Start a thread to write RDD data to the R process.
+ */
+ private def startStdinThread(
+ output: OutputStream,
+ iter: Iterator[_],
+ partitionIndex: Int): Unit = {
+ val env = SparkEnv.get
+ val taskContext = TaskContext.get()
+ val bufferSize = System.getProperty("spark.buffer.size", "65536").toInt
+ val stream = new BufferedOutputStream(output, bufferSize)
+
+ new Thread("writer for R") {
+ override def run(): Unit = {
+ try {
+ SparkEnv.set(env)
+ TaskContext.setTaskContext(taskContext)
+ val dataOut = new DataOutputStream(stream)
+ dataOut.writeInt(partitionIndex)
+
+ SerDe.writeString(dataOut, deserializer)
+ SerDe.writeString(dataOut, serializer)
+
+ dataOut.writeInt(packageNames.length)
+ dataOut.write(packageNames)
+
+ dataOut.writeInt(func.length)
+ dataOut.write(func)
+
+ dataOut.writeInt(broadcastVars.length)
+ broadcastVars.foreach { broadcast =>
+ // TODO(shivaram): Read a Long in R to avoid this cast
+ dataOut.writeInt(broadcast.id.toInt)
+ // TODO: Pass a byte array from R to avoid this cast ?
+ val broadcastByteArr = broadcast.value.asInstanceOf[Array[Byte]]
+ dataOut.writeInt(broadcastByteArr.length)
+ dataOut.write(broadcastByteArr)
+ }
+
+ dataOut.writeInt(numPartitions)
+
+ if (!iter.hasNext) {
+ dataOut.writeInt(0)
+ } else {
+ dataOut.writeInt(1)
+ }
+
+ val printOut = new PrintStream(stream)
+
+ def writeElem(elem: Any): Unit = {
+ if (deserializer == SerializationFormats.BYTE) {
+ val elemArr = elem.asInstanceOf[Array[Byte]]
+ dataOut.writeInt(elemArr.length)
+ dataOut.write(elemArr)
+ } else if (deserializer == SerializationFormats.ROW) {
+ dataOut.write(elem.asInstanceOf[Array[Byte]])
+ } else if (deserializer == SerializationFormats.STRING) {
+ // write string (for StringRRDD)
+ // scalastyle:off println
+ printOut.println(elem)
+ // scalastyle:on println
+ }
+ }
+
+ for (elem <- iter) {
+ elem match {
+ case (key, value) =>
+ writeElem(key)
+ writeElem(value)
+ case _ =>
+ writeElem(elem)
+ }
+ }
+ stream.flush()
+ } catch {
+ // TODO: We should propagate this error to the task thread
+ case e: Exception =>
+ logError("R Writer thread got an exception", e)
+ } finally {
+ Try(output.close())
+ }
+ }
+ }.start()
+ }
+
+ private def read(): U = {
+ try {
+ val length = dataStream.readInt()
+
+ length match {
+ case SpecialLengths.TIMING_DATA =>
+ // Timing data from R worker
+ val boot = dataStream.readDouble - bootTime
+ val init = dataStream.readDouble
+ val broadcast = dataStream.readDouble
+ val input = dataStream.readDouble
+ val compute = dataStream.readDouble
+ val output = dataStream.readDouble
+ logInfo(
+ ("Times: boot = %.3f s, init = %.3f s, broadcast = %.3f s, " +
+ "read-input = %.3f s, compute = %.3f s, write-output = %.3f s, " +
+ "total = %.3f s").format(
+ boot,
+ init,
+ broadcast,
+ input,
+ compute,
+ output,
+ boot + init + broadcast + input + compute + output))
+ read()
+ case length if length >= 0 =>
+ readData(length).asInstanceOf[U]
+ }
+ } catch {
+ case eof: EOFException =>
+ throw new SparkException("R worker exited unexpectedly (crashed)", eof)
+ }
+ }
+
+ private def readShuffledData(length: Int): (Int, Array[Byte]) = {
+ length match {
+ case length if length == 2 =>
+ val hashedKey = dataStream.readInt()
+ val contentPairsLength = dataStream.readInt()
+ val contentPairs = new Array[Byte](contentPairsLength)
+ dataStream.readFully(contentPairs)
+ (hashedKey, contentPairs)
+ case _ => null
+ }
+ }
+
+ private def readByteArrayData(length: Int): Array[Byte] = {
+ length match {
+ case length if length > 0 =>
+ val obj = new Array[Byte](length)
+ dataStream.readFully(obj)
+ obj
+ case _ => null
+ }
+ }
+
+ private def readStringData(length: Int): String = {
+ length match {
+ case length if length > 0 =>
+ SerDe.readStringBytes(dataStream, length)
+ case _ => null
+ }
+ }
+}
+
+private object SpecialLengths {
+ val TIMING_DATA = -1
+}
+
+private[r] class BufferedStreamThread(
+ in: InputStream,
+ name: String,
+ errBufferSize: Int) extends Thread(name) with Logging {
+ val lines = new Array[String](errBufferSize)
+ var lineIdx = 0
+ override def run() {
+ for (line <- Source.fromInputStream(in).getLines) {
+ synchronized {
+ lines(lineIdx) = line
+ lineIdx = (lineIdx + 1) % errBufferSize
+ }
+ logInfo(line)
+ }
+ }
+
+ def getLines(): String = synchronized {
+ (0 until errBufferSize).filter { x =>
+ lines((x + lineIdx) % errBufferSize) != null
+ }.map { x =>
+ lines((x + lineIdx) % errBufferSize)
+ }.mkString("\n")
+ }
+}
+
+private[r] object RRunner {
+ // Because forking processes from Java is expensive, we prefer to launch
+ // a single R daemon (daemon.R) and tell it to fork new workers for our tasks.
+ // This daemon currently only works on UNIX-based systems, so on other platforms we
+ // fall back to launching workers (worker.R) directly.
+ private[this] var errThread: BufferedStreamThread = _
+ private[this] var daemonChannel: DataOutputStream = _
+
+ /**
+ * Start a thread to print the process's stderr to ours
+ */
+ private def startStdoutThread(proc: Process): BufferedStreamThread = {
+ val BUFFER_SIZE = 100
+ val thread = new BufferedStreamThread(proc.getInputStream, "stdout reader for R", BUFFER_SIZE)
+ thread.setDaemon(true)
+ thread.start()
+ thread
+ }
+
+ private def createRProcess(port: Int, script: String): BufferedStreamThread = {
+ // "spark.sparkr.r.command" is deprecated and replaced by "spark.r.command",
+ // but kept here for backward compatibility.
+ val sparkConf = SparkEnv.get.conf
+ var rCommand = sparkConf.get("spark.sparkr.r.command", "Rscript")
+ rCommand = sparkConf.get("spark.r.command", rCommand)
+
+ val rOptions = "--vanilla"
+ val rLibDir = RUtils.sparkRPackagePath(isDriver = false)
+ val rExecScript = rLibDir(0) + "/SparkR/worker/" + script
+ val pb = new ProcessBuilder(Arrays.asList(rCommand, rOptions, rExecScript))
+ // Unset the R_TESTS environment variable for workers.
+ // This is set by R CMD check as startup.Rs
+ // (http://svn.r-project.org/R/trunk/src/library/tools/R/testing.R)
+ // and confuses worker script which tries to load a non-existent file
+ pb.environment().put("R_TESTS", "")
+ pb.environment().put("SPARKR_RLIBDIR", rLibDir.mkString(","))
+ pb.environment().put("SPARKR_WORKER_PORT", port.toString)
+ pb.redirectErrorStream(true) // redirect stderr into stdout
+ val proc = pb.start()
+ val errThread = startStdoutThread(proc)
+ errThread
+ }
+
+ /**
+ * ProcessBuilder used to launch worker R processes.
+ */
+ def createRWorker(port: Int): BufferedStreamThread = {
+ val useDaemon = SparkEnv.get.conf.getBoolean("spark.sparkr.use.daemon", true)
+ if (!Utils.isWindows && useDaemon) {
+ synchronized {
+ if (daemonChannel == null) {
+ // we expect one connection
+ val serverSocket = new ServerSocket(0, 1, InetAddress.getByName("localhost"))
+ val daemonPort = serverSocket.getLocalPort
+ errThread = createRProcess(daemonPort, "daemon.R")
+ // the socket used to send out the input of task
+ serverSocket.setSoTimeout(10000)
+ val sock = serverSocket.accept()
+ daemonChannel = new DataOutputStream(new BufferedOutputStream(sock.getOutputStream))
+ serverSocket.close()
+ }
+ try {
+ daemonChannel.writeInt(port)
+ daemonChannel.flush()
+ } catch {
+ case e: IOException =>
+ // daemon process died
+ daemonChannel.close()
+ daemonChannel = null
+ errThread = null
+ // fail the current task, retry by scheduler
+ throw e
+ }
+ errThread
+ }
+ } else {
+ createRProcess(port, "worker.R")
+ }
+ }
+}
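
A hedged sketch (not part of this diff) of driving the extracted RRunner, mirroring the call
site in BaseRRDD.compute above; serializedFunc, pkgBytes and bcastVars stand in for the values
BaseRRDD already holds:

  val runner = new RRunner[Array[Byte]](
    func = serializedFunc,
    deserializer = SerializationFormats.BYTE,
    serializer = SerializationFormats.BYTE,
    packageNames = pkgBytes,
    broadcastVars = bcastVars,
    numPartitions = -1)                 // -1 selects the plain byte-array read path
  val output = runner.compute(inputIterator, partitionIndex = 0, context = TaskContext.get())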
diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
index e5e6a9e4a8..632b0ae9c2 100644
--- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
+++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
@@ -30,7 +30,7 @@ import org.apache.spark.io.CompressionCodec
import org.apache.spark.serializer.Serializer
import org.apache.spark.storage.{BlockId, BroadcastBlockId, StorageLevel}
import org.apache.spark.util.{ByteBufferInputStream, Utils}
-import org.apache.spark.util.io.{ByteArrayChunkOutputStream, ChunkedByteBuffer}
+import org.apache.spark.util.io.{ChunkedByteBuffer, ChunkedByteBufferOutputStream}
/**
* A BitTorrent-like implementation of [[org.apache.spark.broadcast.Broadcast]].
@@ -228,12 +228,12 @@ private object TorrentBroadcast extends Logging {
blockSize: Int,
serializer: Serializer,
compressionCodec: Option[CompressionCodec]): Array[ByteBuffer] = {
- val bos = new ByteArrayChunkOutputStream(blockSize)
- val out: OutputStream = compressionCodec.map(c => c.compressedOutputStream(bos)).getOrElse(bos)
+ val cbbos = new ChunkedByteBufferOutputStream(blockSize, ByteBuffer.allocate)
+ val out = compressionCodec.map(c => c.compressedOutputStream(cbbos)).getOrElse(cbbos)
val ser = serializer.newInstance()
val serOut = ser.serializeStream(out)
serOut.writeObject[T](obj).close()
- bos.toArrays.map(ByteBuffer.wrap)
+ cbbos.toChunkedByteBuffer.getChunks()
}
def unBlockifyObject[T: ClassTag](
diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala
index 06b7b388ca..cda9d38c6a 100644
--- a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala
@@ -74,13 +74,12 @@ class SparkHadoopUtil extends Logging {
}
}
+
/**
- * Return an appropriate (subclass) of Configuration. Creating config can initializes some Hadoop
- * subsystems.
+ * Appends S3-specific, spark.hadoop.*, and spark.buffer.size configurations to a Hadoop
+ * configuration.
*/
- def newConfiguration(conf: SparkConf): Configuration = {
- val hadoopConf = new Configuration()
-
+ def appendS3AndSparkHadoopConfigurations(conf: SparkConf, hadoopConf: Configuration): Unit = {
// Note: this null check is around more than just access to the "conf" object to maintain
// the behavior of the old implementation of this code, for backwards compatibility.
if (conf != null) {
@@ -106,7 +105,15 @@ class SparkHadoopUtil extends Logging {
val bufferSize = conf.get("spark.buffer.size", "65536")
hadoopConf.set("io.file.buffer.size", bufferSize)
}
+ }
+ /**
+ * Return an appropriate (subclass of) Configuration. Creating a config can initialize some Hadoop
+ * subsystems.
+ */
+ def newConfiguration(conf: SparkConf): Configuration = {
+ val hadoopConf = new Configuration()
+ appendS3AndSparkHadoopConfigurations(conf, hadoopConf)
hadoopConf
}
@@ -145,10 +152,9 @@ class SparkHadoopUtil extends Logging {
val baselineBytesRead = f()
Some(() => f() - baselineBytesRead)
} catch {
- case e @ (_: NoSuchMethodException | _: ClassNotFoundException) => {
+ case e @ (_: NoSuchMethodException | _: ClassNotFoundException) =>
logDebug("Couldn't find method for retrieving thread-level FileSystem input data", e)
None
- }
}
}
@@ -167,10 +173,9 @@ class SparkHadoopUtil extends Logging {
val baselineBytesWritten = f()
Some(() => f() - baselineBytesWritten)
} catch {
- case e @ (_: NoSuchMethodException | _: ClassNotFoundException) => {
+ case e @ (_: NoSuchMethodException | _: ClassNotFoundException) =>
logDebug("Couldn't find method for retrieving thread-level FileSystem output data", e)
None
- }
}
}
@@ -308,7 +313,7 @@ class SparkHadoopUtil extends Logging {
*/
def substituteHadoopVariables(text: String, hadoopConf: Configuration): String = {
text match {
- case HADOOP_CONF_PATTERN(matched) => {
+ case HADOOP_CONF_PATTERN(matched) =>
logDebug(text + " matched " + HADOOP_CONF_PATTERN)
val key = matched.substring(13, matched.length() - 1) // remove ${hadoopconf- .. }
val eval = Option[String](hadoopConf.get(key))
@@ -323,11 +328,9 @@ class SparkHadoopUtil extends Logging {
// Continue to substitute more variables.
substituteHadoopVariables(eval.get, hadoopConf)
}
- }
- case _ => {
+ case _ =>
logDebug(text + " didn't match " + HADOOP_CONF_PATTERN)
text
- }
}
}
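
A minimal sketch (not part of this diff) of the behaviour the appendS3AndSparkHadoopConfigurations
refactor preserves: spark.hadoop.* keys are copied into the Hadoop Configuration with the prefix
stripped, and spark.buffer.size feeds io.file.buffer.size (the fs.s3a key is an illustrative
assumption):

  import org.apache.spark.SparkConf
  import org.apache.spark.deploy.SparkHadoopUtil

  val conf = new SparkConf()
    .set("spark.hadoop.fs.s3a.connection.maximum", "100")
    .set("spark.buffer.size", "131072")
  val hadoopConf = SparkHadoopUtil.get.newConfiguration(conf)
  hadoopConf.get("fs.s3a.connection.maximum")   // "100"
  hadoopConf.get("io.file.buffer.size")         // "131072"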
diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala
index 4049fc0c41..926e1ff7a8 100644
--- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala
@@ -441,7 +441,6 @@ object SparkSubmit {
OptionAssigner(args.deployMode, ALL_CLUSTER_MGRS, ALL_DEPLOY_MODES,
sysProp = "spark.submit.deployMode"),
OptionAssigner(args.name, ALL_CLUSTER_MGRS, ALL_DEPLOY_MODES, sysProp = "spark.app.name"),
- OptionAssigner(args.jars, ALL_CLUSTER_MGRS, CLIENT, sysProp = "spark.jars"),
OptionAssigner(args.ivyRepoPath, ALL_CLUSTER_MGRS, CLIENT, sysProp = "spark.jars.ivy"),
OptionAssigner(args.driverMemory, ALL_CLUSTER_MGRS, CLIENT,
sysProp = "spark.driver.memory"),
@@ -452,27 +451,15 @@ object SparkSubmit {
OptionAssigner(args.driverExtraLibraryPath, ALL_CLUSTER_MGRS, ALL_DEPLOY_MODES,
sysProp = "spark.driver.extraLibraryPath"),
- // Yarn client only
- OptionAssigner(args.queue, YARN, CLIENT, sysProp = "spark.yarn.queue"),
+ // Yarn only
+ OptionAssigner(args.queue, YARN, ALL_DEPLOY_MODES, sysProp = "spark.yarn.queue"),
OptionAssigner(args.numExecutors, YARN, ALL_DEPLOY_MODES,
sysProp = "spark.executor.instances"),
- OptionAssigner(args.files, YARN, CLIENT, sysProp = "spark.yarn.dist.files"),
- OptionAssigner(args.archives, YARN, CLIENT, sysProp = "spark.yarn.dist.archives"),
- OptionAssigner(args.principal, YARN, CLIENT, sysProp = "spark.yarn.principal"),
- OptionAssigner(args.keytab, YARN, CLIENT, sysProp = "spark.yarn.keytab"),
-
- // Yarn cluster only
- OptionAssigner(args.name, YARN, CLUSTER, clOption = "--name"),
- OptionAssigner(args.driverMemory, YARN, CLUSTER, clOption = "--driver-memory"),
- OptionAssigner(args.driverCores, YARN, CLUSTER, clOption = "--driver-cores"),
- OptionAssigner(args.queue, YARN, CLUSTER, clOption = "--queue"),
- OptionAssigner(args.executorMemory, YARN, CLUSTER, clOption = "--executor-memory"),
- OptionAssigner(args.executorCores, YARN, CLUSTER, clOption = "--executor-cores"),
- OptionAssigner(args.files, YARN, CLUSTER, clOption = "--files"),
- OptionAssigner(args.archives, YARN, CLUSTER, clOption = "--archives"),
- OptionAssigner(args.jars, YARN, CLUSTER, clOption = "--addJars"),
- OptionAssigner(args.principal, YARN, CLUSTER, clOption = "--principal"),
- OptionAssigner(args.keytab, YARN, CLUSTER, clOption = "--keytab"),
+ OptionAssigner(args.jars, YARN, ALL_DEPLOY_MODES, sysProp = "spark.yarn.dist.jars"),
+ OptionAssigner(args.files, YARN, ALL_DEPLOY_MODES, sysProp = "spark.yarn.dist.files"),
+ OptionAssigner(args.archives, YARN, ALL_DEPLOY_MODES, sysProp = "spark.yarn.dist.archives"),
+ OptionAssigner(args.principal, YARN, ALL_DEPLOY_MODES, sysProp = "spark.yarn.principal"),
+ OptionAssigner(args.keytab, YARN, ALL_DEPLOY_MODES, sysProp = "spark.yarn.keytab"),
// Other options
OptionAssigner(args.executorCores, STANDALONE | YARN, ALL_DEPLOY_MODES,
@@ -483,10 +470,11 @@ object SparkSubmit {
sysProp = "spark.cores.max"),
OptionAssigner(args.files, LOCAL | STANDALONE | MESOS, ALL_DEPLOY_MODES,
sysProp = "spark.files"),
- OptionAssigner(args.jars, STANDALONE | MESOS, CLUSTER, sysProp = "spark.jars"),
- OptionAssigner(args.driverMemory, STANDALONE | MESOS, CLUSTER,
+ OptionAssigner(args.jars, LOCAL, CLIENT, sysProp = "spark.jars"),
+ OptionAssigner(args.jars, STANDALONE | MESOS, ALL_DEPLOY_MODES, sysProp = "spark.jars"),
+ OptionAssigner(args.driverMemory, STANDALONE | MESOS | YARN, CLUSTER,
sysProp = "spark.driver.memory"),
- OptionAssigner(args.driverCores, STANDALONE | MESOS, CLUSTER,
+ OptionAssigner(args.driverCores, STANDALONE | MESOS | YARN, CLUSTER,
sysProp = "spark.driver.cores"),
OptionAssigner(args.supervise.toString, STANDALONE | MESOS, CLUSTER,
sysProp = "spark.driver.supervise"),
@@ -550,6 +538,10 @@ object SparkSubmit {
if (args.isPython) {
sysProps.put("spark.yarn.isPython", "true")
}
+
+ if (args.pyFiles != null) {
+ sysProps("spark.submit.pyFiles") = args.pyFiles
+ }
}
// assure a keytab is available from any place in a JVM
@@ -576,9 +568,6 @@ object SparkSubmit {
childMainClass = "org.apache.spark.deploy.yarn.Client"
if (args.isPython) {
childArgs += ("--primary-py-file", args.primaryResource)
- if (args.pyFiles != null) {
- childArgs += ("--py-files", args.pyFiles)
- }
childArgs += ("--class", "org.apache.spark.deploy.PythonRunner")
} else if (args.isR) {
val mainFile = new Path(args.primaryResource).getName
@@ -627,7 +616,8 @@ object SparkSubmit {
"spark.jars",
"spark.files",
"spark.yarn.dist.files",
- "spark.yarn.dist.archives")
+ "spark.yarn.dist.archives",
+ "spark.yarn.dist.jars")
pathConfigs.foreach { config =>
// Replace old URIs with resolved URIs, if they exist
sysProps.get(config).foreach { oldValue =>
diff --git a/core/src/main/scala/org/apache/spark/deploy/client/AppClientListener.scala b/core/src/main/scala/org/apache/spark/deploy/client/AppClientListener.scala
index e584952a9a..94506a0cbb 100644
--- a/core/src/main/scala/org/apache/spark/deploy/client/AppClientListener.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/client/AppClientListener.scala
@@ -33,7 +33,8 @@ private[spark] trait AppClientListener {
/** An application death is an unrecoverable failure condition. */
def dead(reason: String): Unit
- def executorAdded(fullId: String, workerId: String, hostPort: String, cores: Int, memory: Int)
+ def executorAdded(
+ fullId: String, workerId: String, hostPort: String, cores: Int, memory: Int): Unit
def executorRemoved(fullId: String, message: String, exitStatus: Option[Int]): Unit
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/LeaderElectionAgent.scala b/core/src/main/scala/org/apache/spark/deploy/master/LeaderElectionAgent.scala
index 70f21fbe0d..52e2854961 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/LeaderElectionAgent.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/LeaderElectionAgent.scala
@@ -32,8 +32,8 @@ trait LeaderElectionAgent {
@DeveloperApi
trait LeaderElectable {
- def electedLeader()
- def revokedLeadership()
+ def electedLeader(): Unit
+ def revokedLeadership(): Unit
}
/** Single-node implementation of LeaderElectionAgent -- we're initially and always the leader. */
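
A trivial sketch (not part of this diff) of the LeaderElectable contract with its now-explicit
Unit return types:

  class LoggingElectable extends LeaderElectable {
    override def electedLeader(): Unit = println("became leader")
    override def revokedLeadership(): Unit = println("leadership revoked")
  }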
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala
index 01901bbf85..b443e8f051 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala
@@ -217,7 +217,7 @@ private[deploy] class Master(
}
override def receive: PartialFunction[Any, Unit] = {
- case ElectedLeader => {
+ case ElectedLeader =>
val (storedApps, storedDrivers, storedWorkers) = persistenceEngine.readPersistedData(rpcEnv)
state = if (storedApps.isEmpty && storedDrivers.isEmpty && storedWorkers.isEmpty) {
RecoveryState.ALIVE
@@ -233,16 +233,14 @@ private[deploy] class Master(
}
}, WORKER_TIMEOUT_MS, TimeUnit.MILLISECONDS)
}
- }
case CompleteRecovery => completeRecovery()
- case RevokedLeadership => {
+ case RevokedLeadership =>
logError("Leadership has been revoked -- master shutting down.")
System.exit(0)
- }
- case RegisterApplication(description, driver) => {
+ case RegisterApplication(description, driver) =>
// TODO Prevent repeated registrations from some driver
if (state == RecoveryState.STANDBY) {
// ignore, don't send response
@@ -255,12 +253,11 @@ private[deploy] class Master(
driver.send(RegisteredApplication(app.id, self))
schedule()
}
- }
- case ExecutorStateChanged(appId, execId, state, message, exitStatus) => {
+ case ExecutorStateChanged(appId, execId, state, message, exitStatus) =>
val execOption = idToApp.get(appId).flatMap(app => app.executors.get(execId))
execOption match {
- case Some(exec) => {
+ case Some(exec) =>
val appInfo = idToApp(appId)
val oldState = exec.state
exec.state = state
@@ -298,22 +295,19 @@ private[deploy] class Master(
}
}
}
- }
case None =>
logWarning(s"Got status update for unknown executor $appId/$execId")
}
- }
- case DriverStateChanged(driverId, state, exception) => {
+ case DriverStateChanged(driverId, state, exception) =>
state match {
case DriverState.ERROR | DriverState.FINISHED | DriverState.KILLED | DriverState.FAILED =>
removeDriver(driverId, state, exception)
case _ =>
throw new Exception(s"Received unexpected state update for driver $driverId: $state")
}
- }
- case Heartbeat(workerId, worker) => {
+ case Heartbeat(workerId, worker) =>
idToWorker.get(workerId) match {
case Some(workerInfo) =>
workerInfo.lastHeartbeat = System.currentTimeMillis()
@@ -327,9 +321,8 @@ private[deploy] class Master(
" This worker was never registered, so ignoring the heartbeat.")
}
}
- }
- case MasterChangeAcknowledged(appId) => {
+ case MasterChangeAcknowledged(appId) =>
idToApp.get(appId) match {
case Some(app) =>
logInfo("Application has been re-registered: " + appId)
@@ -339,9 +332,8 @@ private[deploy] class Master(
}
if (canCompleteRecovery) { completeRecovery() }
- }
- case WorkerSchedulerStateResponse(workerId, executors, driverIds) => {
+ case WorkerSchedulerStateResponse(workerId, executors, driverIds) =>
idToWorker.get(workerId) match {
case Some(worker) =>
logInfo("Worker has been re-registered: " + workerId)
@@ -367,7 +359,6 @@ private[deploy] class Master(
}
if (canCompleteRecovery) { completeRecovery() }
- }
case WorkerLatestState(workerId, executors, driverIds) =>
idToWorker.get(workerId) match {
@@ -397,9 +388,8 @@ private[deploy] class Master(
logInfo(s"Received unregister request from application $applicationId")
idToApp.get(applicationId).foreach(finishApplication)
- case CheckForWorkerTimeOut => {
+ case CheckForWorkerTimeOut =>
timeOutDeadWorkers()
- }
case AttachCompletedRebuildUI(appId) =>
// An asyncRebuildSparkUI has completed, so need to attach to master webUi
@@ -408,7 +398,7 @@ private[deploy] class Master(
override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
case RegisterWorker(
- id, workerHost, workerPort, workerRef, cores, memory, workerWebUiUrl) => {
+ id, workerHost, workerPort, workerRef, cores, memory, workerWebUiUrl) =>
logInfo("Registering worker %s:%d with %d cores, %s RAM".format(
workerHost, workerPort, cores, Utils.megabytesToString(memory)))
if (state == RecoveryState.STANDBY) {
@@ -430,9 +420,8 @@ private[deploy] class Master(
+ workerAddress))
}
}
- }
- case RequestSubmitDriver(description) => {
+ case RequestSubmitDriver(description) =>
if (state != RecoveryState.ALIVE) {
val msg = s"${Utils.BACKUP_STANDALONE_MASTER_PREFIX}: $state. " +
"Can only accept driver submissions in ALIVE state."
@@ -451,9 +440,8 @@ private[deploy] class Master(
context.reply(SubmitDriverResponse(self, true, Some(driver.id),
s"Driver successfully submitted as ${driver.id}"))
}
- }
- case RequestKillDriver(driverId) => {
+ case RequestKillDriver(driverId) =>
if (state != RecoveryState.ALIVE) {
val msg = s"${Utils.BACKUP_STANDALONE_MASTER_PREFIX}: $state. " +
s"Can only kill drivers in ALIVE state."
@@ -484,9 +472,8 @@ private[deploy] class Master(
context.reply(KillDriverResponse(self, driverId, success = false, msg))
}
}
- }
- case RequestDriverStatus(driverId) => {
+ case RequestDriverStatus(driverId) =>
if (state != RecoveryState.ALIVE) {
val msg = s"${Utils.BACKUP_STANDALONE_MASTER_PREFIX}: $state. " +
"Can only request driver status in ALIVE state."
@@ -501,18 +488,15 @@ private[deploy] class Master(
context.reply(DriverStatusResponse(found = false, None, None, None, None))
}
}
- }
- case RequestMasterState => {
+ case RequestMasterState =>
context.reply(MasterStateResponse(
address.host, address.port, restServerBoundPort,
workers.toArray, apps.toArray, completedApps.toArray,
drivers.toArray, completedDrivers.toArray, state))
- }
- case BoundPortsRequest => {
+ case BoundPortsRequest =>
context.reply(BoundPortsResponse(address.port, webUi.boundPort, restServerBoundPort))
- }
case RequestExecutors(appId, requestedTotal) =>
context.reply(handleRequestExecutors(appId, requestedTotal))
@@ -859,10 +843,10 @@ private[deploy] class Master(
addressToApp -= app.driver.address
if (completedApps.size >= RETAINED_APPLICATIONS) {
val toRemove = math.max(RETAINED_APPLICATIONS / 10, 1)
- completedApps.take(toRemove).foreach( a => {
+ completedApps.take(toRemove).foreach { a =>
Option(appIdToUI.remove(a.id)).foreach { ui => webUi.detachSparkUI(ui) }
applicationMetricsSystem.removeSource(a.appSource)
- })
+ }
completedApps.trimStart(toRemove)
}
completedApps += app // Remember it in our history
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/MasterArguments.scala b/core/src/main/scala/org/apache/spark/deploy/master/MasterArguments.scala
index 9cd7458ba0..585e0839d0 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/MasterArguments.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/MasterArguments.scala
@@ -78,7 +78,7 @@ private[master] class MasterArguments(args: Array[String], conf: SparkConf) {
case ("--help") :: tail =>
printUsageAndExit(0)
- case Nil => {}
+ case Nil => // No-op
case _ =>
printUsageAndExit(1)
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/PersistenceEngine.scala b/core/src/main/scala/org/apache/spark/deploy/master/PersistenceEngine.scala
index dddf2be57e..b30bc821b7 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/PersistenceEngine.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/PersistenceEngine.scala
@@ -40,12 +40,12 @@ abstract class PersistenceEngine {
* Defines how the object is serialized and persisted. Implementation will
* depend on the store used.
*/
- def persist(name: String, obj: Object)
+ def persist(name: String, obj: Object): Unit
/**
* Defines how the object referred by its name is removed from the store.
*/
- def unpersist(name: String)
+ def unpersist(name: String): Unit
/**
* Gives all objects, matching a prefix. This defines how objects are
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala
index 79f77212fe..af850e4871 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala
@@ -70,11 +70,10 @@ private[master] class ZooKeeperPersistenceEngine(conf: SparkConf, val serializer
try {
Some(serializer.newInstance().deserialize[T](ByteBuffer.wrap(fileData)))
} catch {
- case e: Exception => {
+ case e: Exception =>
logWarning("Exception while reading persisted file, deleting", e)
zk.delete().forPath(WORKING_DIR + "/" + filename)
None
- }
}
}
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcherArguments.scala b/core/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcherArguments.scala
index b97805a28b..11e13441ee 100644
--- a/core/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcherArguments.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcherArguments.scala
@@ -76,14 +76,13 @@ private[mesos] class MesosClusterDispatcherArguments(args: Array[String], conf:
case ("--help") :: tail =>
printUsageAndExit(0)
- case Nil => {
+ case Nil =>
if (masterUrl == null) {
// scalastyle:off println
System.err.println("--master is required")
// scalastyle:on println
printUsageAndExit(1)
}
- }
case _ =>
printUsageAndExit(1)
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/CommandUtils.scala b/core/src/main/scala/org/apache/spark/deploy/worker/CommandUtils.scala
index a4efafcb27..cba4aaffe2 100644
--- a/core/src/main/scala/org/apache/spark/deploy/worker/CommandUtils.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/CommandUtils.scala
@@ -29,7 +29,7 @@ import org.apache.spark.launcher.WorkerCommandBuilder
import org.apache.spark.util.Utils
/**
- ** Utilities for running commands with the spark classpath.
+ * Utilities for running commands with the spark classpath.
*/
private[deploy]
object CommandUtils extends Logging {
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala b/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala
index 9c6bc5c62f..aad2e91b25 100644
--- a/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala
@@ -218,7 +218,7 @@ private[deploy] class DriverRunner(
}
private[deploy] trait Sleeper {
- def sleep(seconds: Int)
+ def sleep(seconds: Int): Unit
}
// Needed because ProcessBuilder is a final class and cannot be mocked
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala
index f9c92c3bb9..06066248ea 100644
--- a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala
@@ -179,16 +179,14 @@ private[deploy] class ExecutorRunner(
val message = "Command exited with code " + exitCode
worker.send(ExecutorStateChanged(appId, execId, state, Some(message), Some(exitCode)))
} catch {
- case interrupted: InterruptedException => {
+ case interrupted: InterruptedException =>
logInfo("Runner thread for executor " + fullId + " interrupted")
state = ExecutorState.KILLED
killProcess(None)
- }
- case e: Exception => {
+ case e: Exception =>
logError("Error running executor", e)
state = ExecutorState.FAILED
killProcess(Some(e.toString))
- }
}
}
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala
index 1b7637a39c..449beb0811 100755
--- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala
@@ -480,7 +480,7 @@ private[deploy] class Worker(
memoryUsed += memory_
sendToMaster(ExecutorStateChanged(appId, execId, manager.state, None, None))
} catch {
- case e: Exception => {
+ case e: Exception =>
logError(s"Failed to launch executor $appId/$execId for ${appDesc.name}.", e)
if (executors.contains(appId + "/" + execId)) {
executors(appId + "/" + execId).kill()
@@ -488,7 +488,6 @@ private[deploy] class Worker(
}
sendToMaster(ExecutorStateChanged(appId, execId, ExecutorState.FAILED,
Some(e.toString), None))
- }
}
}
@@ -509,7 +508,7 @@ private[deploy] class Worker(
}
}
- case LaunchDriver(driverId, driverDesc) => {
+ case LaunchDriver(driverId, driverDesc) =>
logInfo(s"Asked to launch driver $driverId")
val driver = new DriverRunner(
conf,
@@ -525,9 +524,8 @@ private[deploy] class Worker(
coresUsed += driverDesc.cores
memoryUsed += driverDesc.mem
- }
- case KillDriver(driverId) => {
+ case KillDriver(driverId) =>
logInfo(s"Asked to kill driver $driverId")
drivers.get(driverId) match {
case Some(runner) =>
@@ -535,11 +533,9 @@ private[deploy] class Worker(
case None =>
logError(s"Asked to kill unknown driver $driverId")
}
- }
- case driverStateChanged @ DriverStateChanged(driverId, state, exception) => {
+ case driverStateChanged @ DriverStateChanged(driverId, state, exception) =>
handleDriverStateChanged(driverStateChanged)
- }
case ReregisterWithMaster =>
reregisterWithMaster()
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala
index 391eb41190..777020d4d5 100644
--- a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala
@@ -165,12 +165,11 @@ private[worker] class WorkerArguments(args: Array[String], conf: SparkConf) {
}
// scalastyle:on classforname
} catch {
- case e: Exception => {
+ case e: Exception =>
totalMb = 2*1024
// scalastyle:off println
System.out.println("Failed to get total physical memory. Using " + totalMb + " MB")
// scalastyle:on println
- }
}
// Leave out 1 GB for the operating system, but don't return a negative memory size
math.max(totalMb - 1024, Utils.DEFAULT_DRIVER_MEM_MB)
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala
index 6500cab73b..e75c0cec4a 100644
--- a/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala
@@ -107,20 +107,18 @@ private[ui] class LogPage(parent: WorkerWebUI) extends WebUIPage("logPage") with
}
val content =
- <html>
- <body>
- {linkToMaster}
- <div>
- <div style="float:left; margin-right:10px">{backButton}</div>
- <div style="float:left;">{range}</div>
- <div style="float:right; margin-left:10px">{nextButton}</div>
- </div>
- <br />
- <div style="height:500px; overflow:auto; padding:5px;">
- <pre>{logText}</pre>
- </div>
- </body>
- </html>
+ <div>
+ {linkToMaster}
+ <div>
+ <div style="float:left; margin-right:10px">{backButton}</div>
+ <div style="float:left;">{range}</div>
+ <div style="float:right; margin-left:10px">{nextButton}</div>
+ </div>
+ <br />
+ <div style="height:500px; overflow:auto; padding:5px;">
+ <pre>{logText}</pre>
+ </div>
+ </div>
UIUtils.basicSparkPage(content, logType + " log page for " + pageName)
}
diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
index 320a20033d..71b4ad160d 100644
--- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
+++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
@@ -57,16 +57,14 @@ private[spark] class CoarseGrainedExecutorBackend(
rpcEnv.asyncSetupEndpointRefByURI(driverUrl).flatMap { ref =>
// This is a very fast action so we can use "ThreadUtils.sameThread"
driver = Some(ref)
- ref.ask[RegisterExecutorResponse](RegisterExecutor(executorId, self, cores, extractLogUrls))
+ ref.ask[Boolean](RegisterExecutor(executorId, self, cores, extractLogUrls))
}(ThreadUtils.sameThread).onComplete {
// This is a very fast action so we can use "ThreadUtils.sameThread"
- case Success(msg) => Utils.tryLogNonFatalError {
- Option(self).foreach(_.send(msg)) // msg must be RegisterExecutorResponse
- }
- case Failure(e) => {
+ case Success(msg) =>
+ // Always receives `true`. Just ignore it.
+ case Failure(e) =>
logError(s"Cannot register with driver: $driverUrl", e)
System.exit(1)
- }
}(ThreadUtils.sameThread)
}
@@ -113,9 +111,15 @@ private[spark] class CoarseGrainedExecutorBackend(
case Shutdown =>
stopping.set(true)
- executor.stop()
- stop()
- rpcEnv.shutdown()
+ new Thread("CoarseGrainedExecutorBackend-stop-executor") {
+ override def run(): Unit = {
+ // executor.stop() will call `SparkEnv.stop()` which waits until RpcEnv stops totally.
+ // However, if `executor.stop()` runs in some thread of RpcEnv, RpcEnv won't be able to
+ // stop until `executor.stop()` returns, which becomes a dead-lock (See SPARK-14180).
+ // Therefore, we put this line in a new thread.
+ executor.stop()
+ }
+ }.start()
}
override def onDisconnected(remoteAddress: RpcAddress): Unit = {
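As the comment in the hunk notes, executor.stop() is moved onto its own named thread so the RPC event loop never blocks waiting on its own shutdown (SPARK-14180). A minimal sketch of the same pattern, with illustrative names outside of Spark:

    object StopSketch {
      /** Run blocking shutdown work on a separate, named thread instead of the caller's thread. */
      def stopAsync(threadName: String)(doStop: => Unit): Unit = {
        new Thread(threadName) {
          override def run(): Unit = doStop   // the blocking stop happens here, off the RPC thread
        }.start()
      }
    }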
diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala
index 3201463b8c..9f94fdef24 100644
--- a/core/src/main/scala/org/apache/spark/executor/Executor.scala
+++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -21,6 +21,7 @@ import java.io.{File, NotSerializableException}
import java.lang.management.ManagementFactory
import java.net.URL
import java.nio.ByteBuffer
+import java.util.Properties
import java.util.concurrent.{ConcurrentHashMap, TimeUnit}
import scala.collection.JavaConverters._
@@ -206,9 +207,16 @@ private[spark] class Executor(
startGCTime = computeTotalGcTime()
try {
- val (taskFiles, taskJars, taskBytes) = Task.deserializeWithDependencies(serializedTask)
+ val (taskFiles, taskJars, taskProps, taskBytes) =
+ Task.deserializeWithDependencies(serializedTask)
+
+ // Must be set before updateDependencies() is called, in case fetching dependencies
+ // requires access to properties contained within (e.g. for access control).
+ Executor.taskDeserializationProps.set(taskProps)
+
updateDependencies(taskFiles, taskJars)
task = ser.deserialize[Task[Any]](taskBytes, Thread.currentThread.getContextClassLoader)
+ task.localProperties = taskProps
task.setTaskMemoryManager(taskMemoryManager)
// If this task has been killed before we deserialized it, let's quit now. Otherwise,
@@ -254,7 +262,7 @@ private[spark] class Executor(
if (conf.getBoolean("spark.storage.exceptionOnPinLeak", false) && !threwException) {
throw new SparkException(errMsg)
} else {
- logError(errMsg)
+ logWarning(errMsg)
}
}
}
@@ -321,7 +329,7 @@ private[spark] class Executor(
logInfo(s"Executor killed $taskName (TID $taskId)")
execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(TaskKilled))
- case cDE: CommitDeniedException =>
+ case CausedBy(cDE: CommitDeniedException) =>
val reason = cDE.toTaskEndReason
execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason))
@@ -506,3 +514,10 @@ private[spark] class Executor(
heartbeater.scheduleAtFixedRate(heartbeatTask, initialDelay, intervalMs, TimeUnit.MILLISECONDS)
}
}
+
+private[spark] object Executor {
+ // This is reserved for internal use by components that need to read task properties before a
+ // task is fully deserialized. When possible, the TaskContext.getLocalProperty call should be
+ // used instead.
+ val taskDeserializationProps: ThreadLocal[Properties] = new ThreadLocal[Properties]
+}
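The new Executor.taskDeserializationProps thread-local is populated before updateDependencies() so that code running on the task thread can read the task's properties while dependencies are still being fetched. A small sketch of that ThreadLocal hand-off; names outside Executor are illustrative:

    import java.util.Properties

    object TaskPropsSketch {
      // Same hand-off shape as Executor.taskDeserializationProps above.
      val current: ThreadLocal[Properties] = new ThreadLocal[Properties]

      def runWithProps(props: Properties)(body: => Unit): Unit = {
        current.set(props)
        try body finally current.remove()   // clear afterwards so a reused thread starts clean
      }
    }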
diff --git a/core/src/main/scala/org/apache/spark/executor/ExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/ExecutorBackend.scala
index e07cb31cbe..7153323d01 100644
--- a/core/src/main/scala/org/apache/spark/executor/ExecutorBackend.scala
+++ b/core/src/main/scala/org/apache/spark/executor/ExecutorBackend.scala
@@ -25,6 +25,6 @@ import org.apache.spark.TaskState.TaskState
* A pluggable interface used by the Executor to send updates to the cluster scheduler.
*/
private[spark] trait ExecutorBackend {
- def statusUpdate(taskId: Long, state: TaskState, data: ByteBuffer)
+ def statusUpdate(taskId: Long, state: TaskState, data: ByteBuffer): Unit
}
diff --git a/core/src/main/scala/org/apache/spark/executor/InputMetrics.scala b/core/src/main/scala/org/apache/spark/executor/InputMetrics.scala
index 6d30d3c76a..83e11c5e23 100644
--- a/core/src/main/scala/org/apache/spark/executor/InputMetrics.scala
+++ b/core/src/main/scala/org/apache/spark/executor/InputMetrics.scala
@@ -81,35 +81,9 @@ class InputMetrics private (
*/
def readMethod: DataReadMethod.Value = DataReadMethod.withName(_readMethod.localValue)
- // Once incBytesRead & intRecordsRead is ready to be removed from the public API
- // we can remove the internal versions and make the previous public API private.
- // This has been done to suppress warnings when building.
- @deprecated("incrementing input metrics is for internal use only", "2.0.0")
- def incBytesRead(v: Long): Unit = _bytesRead.add(v)
- private[spark] def incBytesReadInternal(v: Long): Unit = _bytesRead.add(v)
- @deprecated("incrementing input metrics is for internal use only", "2.0.0")
- def incRecordsRead(v: Long): Unit = _recordsRead.add(v)
- private[spark] def incRecordsReadInternal(v: Long): Unit = _recordsRead.add(v)
+ private[spark] def incBytesRead(v: Long): Unit = _bytesRead.add(v)
+ private[spark] def incRecordsRead(v: Long): Unit = _recordsRead.add(v)
private[spark] def setBytesRead(v: Long): Unit = _bytesRead.setValue(v)
- private[spark] def setReadMethod(v: DataReadMethod.Value): Unit =
- _readMethod.setValue(v.toString)
+ private[spark] def setReadMethod(v: DataReadMethod.Value): Unit = _readMethod.setValue(v.toString)
}
-
-/**
- * Deprecated methods to preserve case class matching behavior before Spark 2.0.
- */
-object InputMetrics {
-
- @deprecated("matching on InputMetrics will not be supported in the future", "2.0.0")
- def apply(readMethod: DataReadMethod.Value): InputMetrics = {
- val im = new InputMetrics
- im.setReadMethod(readMethod)
- im
- }
-
- @deprecated("matching on InputMetrics will not be supported in the future", "2.0.0")
- def unapply(input: InputMetrics): Option[DataReadMethod.Value] = {
- Some(input.readMethod)
- }
-}
diff --git a/core/src/main/scala/org/apache/spark/executor/OutputMetrics.scala b/core/src/main/scala/org/apache/spark/executor/OutputMetrics.scala
index 0b37d559c7..93f953846f 100644
--- a/core/src/main/scala/org/apache/spark/executor/OutputMetrics.scala
+++ b/core/src/main/scala/org/apache/spark/executor/OutputMetrics.scala
@@ -52,18 +52,6 @@ class OutputMetrics private (
}
/**
- * Create a new [[OutputMetrics]] that is not associated with any particular task.
- *
- * This is only used for preserving matching behavior on [[OutputMetrics]], which used to be
- * a case class before Spark 2.0. Once we remove support for matching on [[OutputMetrics]]
- * we can remove this constructor as well.
- */
- private[executor] def this() {
- this(InternalAccumulator.createOutputAccums()
- .map { a => (a.name.get, a) }.toMap[String, Accumulator[_]])
- }
-
- /**
* Total number of bytes written.
*/
def bytesWritten: Long = _bytesWritten.localValue
@@ -84,21 +72,3 @@ class OutputMetrics private (
_writeMethod.setValue(v.toString)
}
-
-/**
- * Deprecated methods to preserve case class matching behavior before Spark 2.0.
- */
-object OutputMetrics {
-
- @deprecated("matching on OutputMetrics will not be supported in the future", "2.0.0")
- def apply(writeMethod: DataWriteMethod.Value): OutputMetrics = {
- val om = new OutputMetrics
- om.setWriteMethod(writeMethod)
- om
- }
-
- @deprecated("matching on OutputMetrics will not be supported in the future", "2.0.0")
- def unapply(output: OutputMetrics): Option[DataWriteMethod.Value] = {
- Some(output.writeMethod)
- }
-}
diff --git a/core/src/main/scala/org/apache/spark/executor/ShuffleReadMetrics.scala b/core/src/main/scala/org/apache/spark/executor/ShuffleReadMetrics.scala
index 50bb645d97..71a24770b5 100644
--- a/core/src/main/scala/org/apache/spark/executor/ShuffleReadMetrics.scala
+++ b/core/src/main/scala/org/apache/spark/executor/ShuffleReadMetrics.scala
@@ -116,4 +116,25 @@ class ShuffleReadMetrics private (
private[spark] def setFetchWaitTime(v: Long): Unit = _fetchWaitTime.setValue(v)
private[spark] def setRecordsRead(v: Long): Unit = _recordsRead.setValue(v)
+ /**
+ * Resets the value of the current metrics (`this`) and merges all the independent
+ * [[ShuffleReadMetrics]] into `this`.
+ */
+ private[spark] def setMergeValues(metrics: Seq[ShuffleReadMetrics]): Unit = {
+ _remoteBlocksFetched.setValue(_remoteBlocksFetched.zero)
+ _localBlocksFetched.setValue(_localBlocksFetched.zero)
+ _remoteBytesRead.setValue(_remoteBytesRead.zero)
+ _localBytesRead.setValue(_localBytesRead.zero)
+ _fetchWaitTime.setValue(_fetchWaitTime.zero)
+ _recordsRead.setValue(_recordsRead.zero)
+ metrics.foreach { metric =>
+ _remoteBlocksFetched.add(metric.remoteBlocksFetched)
+ _localBlocksFetched.add(metric.localBlocksFetched)
+ _remoteBytesRead.add(metric.remoteBytesRead)
+ _localBytesRead.add(metric.localBytesRead)
+ _fetchWaitTime.add(metric.fetchWaitTime)
+ _recordsRead.add(metric.recordsRead)
+ }
+ }
+
}
diff --git a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala
index 02219a84ab..bda2a91d9d 100644
--- a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala
+++ b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala
@@ -139,16 +139,6 @@ class TaskMetrics private[spark] (initialAccums: Seq[Accumulator[_]]) extends Se
*/
def updatedBlockStatuses: Seq[(BlockId, BlockStatus)] = _updatedBlockStatuses.localValue
- @deprecated("use updatedBlockStatuses instead", "2.0.0")
- def updatedBlocks: Option[Seq[(BlockId, BlockStatus)]] = {
- if (updatedBlockStatuses.nonEmpty) Some(updatedBlockStatuses) else None
- }
-
- @deprecated("setting updated blocks is not allowed", "2.0.0")
- def updatedBlocks_=(blocks: Option[Seq[(BlockId, BlockStatus)]]): Unit = {
- blocks.foreach(setUpdatedBlockStatuses)
- }
-
// Setters and increment-ers
private[spark] def setExecutorDeserializeTime(v: Long): Unit =
_executorDeserializeTime.setValue(v)
@@ -225,11 +215,6 @@ class TaskMetrics private[spark] (initialAccums: Seq[Accumulator[_]]) extends Se
*/
def outputMetrics: Option[OutputMetrics] = _outputMetrics
- @deprecated("setting OutputMetrics is for internal use only", "2.0.0")
- def outputMetrics_=(om: Option[OutputMetrics]): Unit = {
- _outputMetrics = om
- }
-
/**
* Get or create a new [[OutputMetrics]] associated with this task.
*/
@@ -285,12 +270,7 @@ class TaskMetrics private[spark] (initialAccums: Seq[Accumulator[_]]) extends Se
private[spark] def mergeShuffleReadMetrics(): Unit = synchronized {
if (tempShuffleReadMetrics.nonEmpty) {
val metrics = new ShuffleReadMetrics(initialAccumsMap)
- metrics.setRemoteBlocksFetched(tempShuffleReadMetrics.map(_.remoteBlocksFetched).sum)
- metrics.setLocalBlocksFetched(tempShuffleReadMetrics.map(_.localBlocksFetched).sum)
- metrics.setFetchWaitTime(tempShuffleReadMetrics.map(_.fetchWaitTime).sum)
- metrics.setRemoteBytesRead(tempShuffleReadMetrics.map(_.remoteBytesRead).sum)
- metrics.setLocalBytesRead(tempShuffleReadMetrics.map(_.localBytesRead).sum)
- metrics.setRecordsRead(tempShuffleReadMetrics.map(_.recordsRead).sum)
+ metrics.setMergeValues(tempShuffleReadMetrics)
_shuffleReadMetrics = Some(metrics)
}
}
@@ -306,11 +286,6 @@ class TaskMetrics private[spark] (initialAccums: Seq[Accumulator[_]]) extends Se
*/
def shuffleWriteMetrics: Option[ShuffleWriteMetrics] = _shuffleWriteMetrics
- @deprecated("setting ShuffleWriteMetrics is for internal use only", "2.0.0")
- def shuffleWriteMetrics_=(swm: Option[ShuffleWriteMetrics]): Unit = {
- _shuffleWriteMetrics = swm
- }
-
/**
* Get or create a new [[ShuffleWriteMetrics]] associated with this task.
*/
diff --git a/core/src/main/scala/org/apache/spark/internal/config/ConfigBuilder.scala b/core/src/main/scala/org/apache/spark/internal/config/ConfigBuilder.scala
index 770b43697a..5d50e3851a 100644
--- a/core/src/main/scala/org/apache/spark/internal/config/ConfigBuilder.scala
+++ b/core/src/main/scala/org/apache/spark/internal/config/ConfigBuilder.scala
@@ -85,10 +85,12 @@ private[spark] class TypedConfigBuilder[T](
this(parent, converter, Option(_).map(_.toString).orNull)
}
+ /** Apply a transformation to the user-provided values of the config entry. */
def transform(fn: T => T): TypedConfigBuilder[T] = {
new TypedConfigBuilder(parent, s => fn(converter(s)), stringConverter)
}
+ /** Check that user-provided values for the config match a pre-defined set. */
def checkValues(validValues: Set[T]): TypedConfigBuilder[T] = {
transform { v =>
if (!validValues.contains(v)) {
@@ -99,30 +101,38 @@ private[spark] class TypedConfigBuilder[T](
}
}
+ /** Turns the config entry into a sequence of values of the underlying type. */
def toSequence: TypedConfigBuilder[Seq[T]] = {
new TypedConfigBuilder(parent, stringToSeq(_, converter), seqToString(_, stringConverter))
}
- /** Creates a [[ConfigEntry]] that does not require a default value. */
- def optional: OptionalConfigEntry[T] = {
- new OptionalConfigEntry[T](parent.key, converter, stringConverter, parent._doc, parent._public)
+ /** Creates a [[ConfigEntry]] that does not have a default value. */
+ def createOptional: OptionalConfigEntry[T] = {
+ val entry = new OptionalConfigEntry[T](parent.key, converter, stringConverter, parent._doc,
+ parent._public)
+ parent._onCreate.foreach(_(entry))
+ entry
}
/** Creates a [[ConfigEntry]] that has a default value. */
- def withDefault(default: T): ConfigEntry[T] = {
+ def createWithDefault(default: T): ConfigEntry[T] = {
val transformedDefault = converter(stringConverter(default))
- new ConfigEntryWithDefault[T](parent.key, transformedDefault, converter, stringConverter,
- parent._doc, parent._public)
+ val entry = new ConfigEntryWithDefault[T](parent.key, transformedDefault, converter,
+ stringConverter, parent._doc, parent._public)
+ parent._onCreate.foreach(_(entry))
+ entry
}
/**
* Creates a [[ConfigEntry]] that has a default value. The default value is provided as a
* [[String]] and must be a valid value for the entry.
*/
- def withDefaultString(default: String): ConfigEntry[T] = {
+ def createWithDefaultString(default: String): ConfigEntry[T] = {
val typedDefault = converter(default)
- new ConfigEntryWithDefault[T](parent.key, typedDefault, converter, stringConverter, parent._doc,
- parent._public)
+ val entry = new ConfigEntryWithDefault[T](parent.key, typedDefault, converter, stringConverter,
+ parent._doc, parent._public)
+ parent._onCreate.foreach(_(entry))
+ entry
}
}
@@ -136,10 +146,11 @@ private[spark] case class ConfigBuilder(key: String) {
import ConfigHelpers._
- var _public = true
- var _doc = ""
+ private[config] var _public = true
+ private[config] var _doc = ""
+ private[config] var _onCreate: Option[ConfigEntry[_] => Unit] = None
- def internal: ConfigBuilder = {
+ def internal(): ConfigBuilder = {
_public = false
this
}
@@ -149,6 +160,15 @@ private[spark] case class ConfigBuilder(key: String) {
this
}
+ /**
+ * Registers a callback for when the config entry is finally instantiated. Currently used by
+ * SQLConf to keep track of SQL configuration entries.
+ */
+ def onCreate(callback: ConfigEntry[_] => Unit): ConfigBuilder = {
+ _onCreate = Option(callback)
+ this
+ }
+
def intConf: TypedConfigBuilder[Int] = {
new TypedConfigBuilder(this, toNumber(_, _.toInt, key, "int"))
}
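The renamed builder methods (createOptional, createWithDefault, createWithDefaultString) and the new onCreate hook are exercised by the config package object in the next hunk. For quick reference, a hypothetical entry written against the new names (the spark.example.* keys are placeholders, not real configuration keys) would look like:

    val exampleEnabled = ConfigBuilder("spark.example.enabled").booleanConf.createWithDefault(false)
    val examplePath    = ConfigBuilder("spark.example.path").stringConf.createOptional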
diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala
index f2f20b3207..94b50ee065 100644
--- a/core/src/main/scala/org/apache/spark/internal/config/package.scala
+++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala
@@ -18,59 +18,75 @@
package org.apache.spark.internal
import org.apache.spark.launcher.SparkLauncher
+import org.apache.spark.network.util.ByteUnit
package object config {
private[spark] val DRIVER_CLASS_PATH =
- ConfigBuilder(SparkLauncher.DRIVER_EXTRA_CLASSPATH).stringConf.optional
+ ConfigBuilder(SparkLauncher.DRIVER_EXTRA_CLASSPATH).stringConf.createOptional
private[spark] val DRIVER_JAVA_OPTIONS =
- ConfigBuilder(SparkLauncher.DRIVER_EXTRA_JAVA_OPTIONS).stringConf.optional
+ ConfigBuilder(SparkLauncher.DRIVER_EXTRA_JAVA_OPTIONS).stringConf.createOptional
private[spark] val DRIVER_LIBRARY_PATH =
- ConfigBuilder(SparkLauncher.DRIVER_EXTRA_LIBRARY_PATH).stringConf.optional
+ ConfigBuilder(SparkLauncher.DRIVER_EXTRA_LIBRARY_PATH).stringConf.createOptional
private[spark] val DRIVER_USER_CLASS_PATH_FIRST =
- ConfigBuilder("spark.driver.userClassPathFirst").booleanConf.withDefault(false)
+ ConfigBuilder("spark.driver.userClassPathFirst").booleanConf.createWithDefault(false)
+
+ private[spark] val DRIVER_MEMORY = ConfigBuilder("spark.driver.memory")
+ .bytesConf(ByteUnit.MiB)
+ .createWithDefaultString("1g")
private[spark] val EXECUTOR_CLASS_PATH =
- ConfigBuilder(SparkLauncher.EXECUTOR_EXTRA_CLASSPATH).stringConf.optional
+ ConfigBuilder(SparkLauncher.EXECUTOR_EXTRA_CLASSPATH).stringConf.createOptional
private[spark] val EXECUTOR_JAVA_OPTIONS =
- ConfigBuilder(SparkLauncher.EXECUTOR_EXTRA_JAVA_OPTIONS).stringConf.optional
+ ConfigBuilder(SparkLauncher.EXECUTOR_EXTRA_JAVA_OPTIONS).stringConf.createOptional
private[spark] val EXECUTOR_LIBRARY_PATH =
- ConfigBuilder(SparkLauncher.EXECUTOR_EXTRA_LIBRARY_PATH).stringConf.optional
+ ConfigBuilder(SparkLauncher.EXECUTOR_EXTRA_LIBRARY_PATH).stringConf.createOptional
private[spark] val EXECUTOR_USER_CLASS_PATH_FIRST =
- ConfigBuilder("spark.executor.userClassPathFirst").booleanConf.withDefault(false)
+ ConfigBuilder("spark.executor.userClassPathFirst").booleanConf.createWithDefault(false)
+
+ private[spark] val EXECUTOR_MEMORY = ConfigBuilder("spark.executor.memory")
+ .bytesConf(ByteUnit.MiB)
+ .createWithDefaultString("1g")
- private[spark] val IS_PYTHON_APP = ConfigBuilder("spark.yarn.isPython").internal
- .booleanConf.withDefault(false)
+ private[spark] val IS_PYTHON_APP = ConfigBuilder("spark.yarn.isPython").internal()
+ .booleanConf.createWithDefault(false)
- private[spark] val CPUS_PER_TASK = ConfigBuilder("spark.task.cpus").intConf.withDefault(1)
+ private[spark] val CPUS_PER_TASK = ConfigBuilder("spark.task.cpus").intConf.createWithDefault(1)
private[spark] val DYN_ALLOCATION_MIN_EXECUTORS =
- ConfigBuilder("spark.dynamicAllocation.minExecutors").intConf.withDefault(0)
+ ConfigBuilder("spark.dynamicAllocation.minExecutors").intConf.createWithDefault(0)
private[spark] val DYN_ALLOCATION_INITIAL_EXECUTORS =
ConfigBuilder("spark.dynamicAllocation.initialExecutors")
.fallbackConf(DYN_ALLOCATION_MIN_EXECUTORS)
private[spark] val DYN_ALLOCATION_MAX_EXECUTORS =
- ConfigBuilder("spark.dynamicAllocation.maxExecutors").intConf.withDefault(Int.MaxValue)
+ ConfigBuilder("spark.dynamicAllocation.maxExecutors").intConf.createWithDefault(Int.MaxValue)
private[spark] val SHUFFLE_SERVICE_ENABLED =
- ConfigBuilder("spark.shuffle.service.enabled").booleanConf.withDefault(false)
+ ConfigBuilder("spark.shuffle.service.enabled").booleanConf.createWithDefault(false)
private[spark] val KEYTAB = ConfigBuilder("spark.yarn.keytab")
.doc("Location of user's keytab.")
- .stringConf.optional
+ .stringConf.createOptional
private[spark] val PRINCIPAL = ConfigBuilder("spark.yarn.principal")
.doc("Name of the Kerberos principal.")
- .stringConf.optional
+ .stringConf.createOptional
- private[spark] val EXECUTOR_INSTANCES = ConfigBuilder("spark.executor.instances").intConf.optional
+ private[spark] val EXECUTOR_INSTANCES = ConfigBuilder("spark.executor.instances")
+ .intConf
+ .createOptional
+ private[spark] val PY_FILES = ConfigBuilder("spark.submit.pyFiles")
+ .internal()
+ .stringConf
+ .toSequence
+ .createWithDefault(Nil)
}
diff --git a/core/src/main/scala/org/apache/spark/launcher/WorkerCommandBuilder.scala b/core/src/main/scala/org/apache/spark/launcher/WorkerCommandBuilder.scala
index a2add61617..31b9c5edf0 100644
--- a/core/src/main/scala/org/apache/spark/launcher/WorkerCommandBuilder.scala
+++ b/core/src/main/scala/org/apache/spark/launcher/WorkerCommandBuilder.scala
@@ -37,7 +37,6 @@ private[spark] class WorkerCommandBuilder(sparkHome: String, memoryMb: Int, comm
override def buildCommand(env: JMap[String, String]): JList[String] = {
val cmd = buildJavaCommand(command.classPathEntries.mkString(File.pathSeparator))
- cmd.add(s"-Xms${memoryMb}M")
cmd.add(s"-Xmx${memoryMb}M")
command.javaOpts.foreach(cmd.add)
CommandBuilderUtils.addPermGenSizeOpt(cmd)
diff --git a/core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala b/core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala
index 891facba33..607283a306 100644
--- a/core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala
+++ b/core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala
@@ -33,11 +33,8 @@ object SparkHadoopMapRedUtil extends Logging {
* the driver in order to determine whether this attempt can commit (please see SPARK-4879 for
* details).
*
- * Output commit coordinator is only contacted when the following two configurations are both set
- * to `true`:
- *
- * - `spark.speculation`
- * - `spark.hadoop.outputCommitCoordination.enabled`
+ * Output commit coordinator is only used when `spark.hadoop.outputCommitCoordination.enabled`
+ * is set to true (which is the default).
*/
def commitTask(
committer: MapReduceOutputCommitter,
@@ -64,11 +61,10 @@ object SparkHadoopMapRedUtil extends Logging {
if (committer.needsTaskCommit(mrTaskContext)) {
val shouldCoordinateWithDriver: Boolean = {
val sparkConf = SparkEnv.get.conf
- // We only need to coordinate with the driver if there are multiple concurrent task
- // attempts, which should only occur if speculation is enabled
- val speculationEnabled = sparkConf.getBoolean("spark.speculation", defaultValue = false)
- // This (undocumented) setting is an escape-hatch in case the commit code introduces bugs
- sparkConf.getBoolean("spark.hadoop.outputCommitCoordination.enabled", speculationEnabled)
+ // We only need to coordinate with the driver if there are concurrent task attempts.
+ // Note that this could happen even when speculation is not enabled (e.g. see SPARK-8029).
+ // This (undocumented) setting is an escape-hatch in case the commit code introduces bugs.
+ sparkConf.getBoolean("spark.hadoop.outputCommitCoordination.enabled", defaultValue = true)
}
if (shouldCoordinateWithDriver) {
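With this change, coordination through the output commit coordinator is on by default instead of being tied to spark.speculation; the undocumented key remains only as an escape hatch. For example, it could still be switched off explicitly (mainly useful when debugging a suspected coordinator bug):

    import org.apache.spark.SparkConf

    val conf = new SparkConf()
      .set("spark.hadoop.outputCommitCoordination.enabled", "false")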
diff --git a/core/src/main/scala/org/apache/spark/memory/MemoryManager.scala b/core/src/main/scala/org/apache/spark/memory/MemoryManager.scala
index 10656bc8c8..0210217e41 100644
--- a/core/src/main/scala/org/apache/spark/memory/MemoryManager.scala
+++ b/core/src/main/scala/org/apache/spark/memory/MemoryManager.scala
@@ -23,6 +23,7 @@ import org.apache.spark.SparkConf
import org.apache.spark.internal.Logging
import org.apache.spark.storage.BlockId
import org.apache.spark.storage.memory.MemoryStore
+import org.apache.spark.unsafe.Platform
import org.apache.spark.unsafe.array.ByteArrayMethods
import org.apache.spark.unsafe.memory.MemoryAllocator
@@ -190,6 +191,8 @@ private[spark] abstract class MemoryManager(
if (conf.getBoolean("spark.memory.offHeap.enabled", false)) {
require(conf.getSizeAsBytes("spark.memory.offHeap.size", 0) > 0,
"spark.memory.offHeap.size must be > 0 when spark.memory.offHeap.enabled == true")
+ require(Platform.unaligned(),
+ "No support for unaligned Unsafe. Set spark.memory.offHeap.enabled to false.")
MemoryMode.OFF_HEAP
} else {
MemoryMode.ON_HEAP
diff --git a/core/src/main/scala/org/apache/spark/memory/StorageMemoryPool.scala b/core/src/main/scala/org/apache/spark/memory/StorageMemoryPool.scala
index a67e8da26b..0b552cabfc 100644
--- a/core/src/main/scala/org/apache/spark/memory/StorageMemoryPool.scala
+++ b/core/src/main/scala/org/apache/spark/memory/StorageMemoryPool.scala
@@ -35,6 +35,11 @@ private[memory] class StorageMemoryPool(
memoryMode: MemoryMode
) extends MemoryPool(lock) with Logging {
+ private[this] val poolName: String = memoryMode match {
+ case MemoryMode.ON_HEAP => "on-heap storage"
+ case MemoryMode.OFF_HEAP => "off-heap storage"
+ }
+
@GuardedBy("lock")
private[this] var _memoryUsed: Long = 0L
@@ -60,7 +65,7 @@ private[memory] class StorageMemoryPool(
/**
* Acquire N bytes of memory to cache the given block, evicting existing ones if necessary.
- *
+ *
* @return whether all N bytes were successfully granted.
*/
def acquireMemory(blockId: BlockId, numBytes: Long): Boolean = lock.synchronized {
@@ -83,9 +88,8 @@ private[memory] class StorageMemoryPool(
assert(numBytesToAcquire >= 0)
assert(numBytesToFree >= 0)
assert(memoryUsed <= poolSize)
- // Once we support off-heap caching, this will need to change:
- if (numBytesToFree > 0 && memoryMode == MemoryMode.ON_HEAP) {
- memoryStore.evictBlocksToFreeSpace(Some(blockId), numBytesToFree)
+ if (numBytesToFree > 0) {
+ memoryStore.evictBlocksToFreeSpace(Some(blockId), numBytesToFree, memoryMode)
}
// NOTE: If the memory store evicts blocks, then those evictions will synchronously call
// back into this StorageMemoryPool in order to free memory. Therefore, these variables
@@ -122,14 +126,8 @@ private[memory] class StorageMemoryPool(
val remainingSpaceToFree = spaceToFree - spaceFreedByReleasingUnusedMemory
if (remainingSpaceToFree > 0) {
// If reclaiming free memory did not adequately shrink the pool, begin evicting blocks:
- val spaceFreedByEviction = {
- // Once we support off-heap caching, this will need to change:
- if (memoryMode == MemoryMode.ON_HEAP) {
- memoryStore.evictBlocksToFreeSpace(None, remainingSpaceToFree)
- } else {
- 0
- }
- }
+ val spaceFreedByEviction =
+ memoryStore.evictBlocksToFreeSpace(None, remainingSpaceToFree, memoryMode)
// When a block is released, BlockManager.dropFromMemory() calls releaseMemory(), so we do
// not need to decrement _memoryUsed here. However, we do need to decrement the pool size.
decrementPoolSize(spaceFreedByEviction)
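Together with the MemoryManager hunk above, which now also requires unaligned Unsafe access, these changes let blocks be evicted from the off-heap storage pool as well. Both keys checked there can be set as usual; the 512m size below is only an example value:

    import org.apache.spark.SparkConf

    val conf = new SparkConf()
      .set("spark.memory.offHeap.enabled", "true")
      .set("spark.memory.offHeap.size", "512m")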
diff --git a/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala b/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala
index 4da1017d28..0fed991049 100644
--- a/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala
+++ b/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala
@@ -196,10 +196,9 @@ private[spark] class MetricsSystem private (
sinks += sink.asInstanceOf[Sink]
}
} catch {
- case e: Exception => {
+ case e: Exception =>
logError("Sink class " + classPath + " cannot be instantiated")
throw e
- }
}
}
}
diff --git a/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala
index e43e3a2de2..09ce012e4e 100644
--- a/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala
+++ b/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala
@@ -36,7 +36,7 @@ abstract class BlockTransferService extends ShuffleClient with Closeable with Lo
* Initialize the transfer service by giving it the BlockDataManager that can be used to fetch
* local blocks or put local blocks.
*/
- def init(blockDataManager: BlockDataManager)
+ def init(blockDataManager: BlockDataManager): Unit
/**
* Tear down the transfer service.
diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala
index 5f3d4532dd..33a3219607 100644
--- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala
+++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala
@@ -39,7 +39,11 @@ import org.apache.spark.util.Utils
/**
* A BlockTransferService that uses Netty to fetch a set of blocks at a time.
*/
-class NettyBlockTransferService(conf: SparkConf, securityManager: SecurityManager, numCores: Int)
+private[spark] class NettyBlockTransferService(
+ conf: SparkConf,
+ securityManager: SecurityManager,
+ override val hostName: String,
+ numCores: Int)
extends BlockTransferService {
// TODO: Don't use Java serialization, use a more cross-version compatible serialization format.
@@ -65,13 +69,13 @@ class NettyBlockTransferService(conf: SparkConf, securityManager: SecurityManage
clientFactory = transportContext.createClientFactory(clientBootstrap.toSeq.asJava)
server = createServer(serverBootstrap.toList)
appId = conf.getAppId
- logInfo("Server created on " + server.getPort)
+ logInfo(s"Server created on ${hostName}:${server.getPort}")
}
/** Creates and binds the TransportServer, possibly trying multiple ports. */
private def createServer(bootstraps: List[TransportServerBootstrap]): TransportServer = {
def startService(port: Int): (TransportServer, Int) = {
- val server = transportContext.createServer(port, bootstraps.asJava)
+ val server = transportContext.createServer(hostName, port, bootstraps.asJava)
(server, server.getPort)
}
@@ -109,8 +113,6 @@ class NettyBlockTransferService(conf: SparkConf, securityManager: SecurityManage
}
}
- override def hostName: String = Utils.localHostName()
-
override def port: Int = server.getPort
override def uploadBlock(
diff --git a/core/src/main/scala/org/apache/spark/partial/BoundedDouble.scala b/core/src/main/scala/org/apache/spark/partial/BoundedDouble.scala
index 48b9434153..ab6aba6fc7 100644
--- a/core/src/main/scala/org/apache/spark/partial/BoundedDouble.scala
+++ b/core/src/main/scala/org/apache/spark/partial/BoundedDouble.scala
@@ -21,5 +21,22 @@ package org.apache.spark.partial
* A Double value with error bars and associated confidence.
*/
class BoundedDouble(val mean: Double, val confidence: Double, val low: Double, val high: Double) {
+
override def toString(): String = "[%.3f, %.3f]".format(low, high)
+
+ override def hashCode: Int =
+ this.mean.hashCode ^ this.confidence.hashCode ^ this.low.hashCode ^ this.high.hashCode
+
+ /**
+ * Note that, consistent with Double, any NaN value will make equality false.
+ */
+ override def equals(that: Any): Boolean =
+ that match {
+ case that: BoundedDouble =>
+ this.mean == that.mean &&
+ this.confidence == that.confidence &&
+ this.low == that.low &&
+ this.high == that.high
+ case _ => false
+ }
}
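As the added comment says, equality follows Double semantics, so any NaN field makes equals return false, even when comparing an instance with itself:

    val b = new BoundedDouble(Double.NaN, 0.95, 1.0, 2.0)
    val equalToItself = b == b   // false, because Double.NaN != Double.NaN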
diff --git a/core/src/main/scala/org/apache/spark/partial/SumEvaluator.scala b/core/src/main/scala/org/apache/spark/partial/SumEvaluator.scala
index 44295e5a1a..5fe3358316 100644
--- a/core/src/main/scala/org/apache/spark/partial/SumEvaluator.scala
+++ b/core/src/main/scala/org/apache/spark/partial/SumEvaluator.scala
@@ -29,8 +29,9 @@ import org.apache.spark.util.StatCounter
private[spark] class SumEvaluator(totalOutputs: Int, confidence: Double)
extends ApproximateEvaluator[StatCounter, BoundedDouble] {
+ // modified in merge
var outputsMerged = 0
- var counter = new StatCounter
+ val counter = new StatCounter
override def merge(outputId: Int, taskResult: StatCounter) {
outputsMerged += 1
@@ -40,30 +41,39 @@ private[spark] class SumEvaluator(totalOutputs: Int, confidence: Double)
override def currentResult(): BoundedDouble = {
if (outputsMerged == totalOutputs) {
new BoundedDouble(counter.sum, 1.0, counter.sum, counter.sum)
- } else if (outputsMerged == 0) {
+ } else if (outputsMerged == 0 || counter.count == 0) {
new BoundedDouble(0, 0.0, Double.NegativeInfinity, Double.PositiveInfinity)
} else {
val p = outputsMerged.toDouble / totalOutputs
val meanEstimate = counter.mean
- val meanVar = counter.sampleVariance / counter.count
val countEstimate = (counter.count + 1 - p) / p
- val countVar = (counter.count + 1) * (1 - p) / (p * p)
val sumEstimate = meanEstimate * countEstimate
- val sumVar = (meanEstimate * meanEstimate * countVar) +
- (countEstimate * countEstimate * meanVar) +
- (meanVar * countVar)
- val sumStdev = math.sqrt(sumVar)
- val confFactor = {
- if (counter.count > 100) {
+
+ val meanVar = counter.sampleVariance / counter.count
+
+ // Branch at this point because counter.count == 1 implies counter.sampleVariance == NaN
+ // and we don't want to ever return a bound of NaN
+ if (meanVar.isNaN || counter.count == 1) {
+ new BoundedDouble(sumEstimate, confidence, Double.NegativeInfinity, Double.PositiveInfinity)
+ } else {
+ val countVar = (counter.count + 1) * (1 - p) / (p * p)
+ val sumVar = (meanEstimate * meanEstimate * countVar) +
+ (countEstimate * countEstimate * meanVar) +
+ (meanVar * countVar)
+ val sumStdev = math.sqrt(sumVar)
+ val confFactor = if (counter.count > 100) {
new NormalDistribution().inverseCumulativeProbability(1 - (1 - confidence) / 2)
} else {
+ // Note that if degreesOfFreedom goes to 0, TDistribution will throw an exception.
+ // Hence the special case for counter.count == 1 above.
val degreesOfFreedom = (counter.count - 1).toInt
new TDistribution(degreesOfFreedom).inverseCumulativeProbability(1 - (1 - confidence) / 2)
}
+
+ val low = sumEstimate - confFactor * sumStdev
+ val high = sumEstimate + confFactor * sumStdev
+ new BoundedDouble(sumEstimate, confidence, low, high)
}
- val low = sumEstimate - confFactor * sumStdev
- val high = sumEstimate + confFactor * sumStdev
- new BoundedDouble(sumEstimate, confidence, low, high)
}
}
}
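Spelling out what the new branch computes: with p = outputsMerged / totalOutputs, \hat{\mu} = counter.mean, c = counter.count and s^2 = counter.sampleVariance,

    \hat{N} = \frac{c + 1 - p}{p}, \qquad
    \hat{S} = \hat{\mu}\,\hat{N}, \qquad
    \operatorname{Var}(\hat{\mu}) = \frac{s^2}{c}, \qquad
    \operatorname{Var}(\hat{N}) = \frac{(c + 1)(1 - p)}{p^2}

    \operatorname{Var}(\hat{S}) \approx \hat{\mu}^2\,\operatorname{Var}(\hat{N})
      + \hat{N}^2\,\operatorname{Var}(\hat{\mu})
      + \operatorname{Var}(\hat{\mu})\,\operatorname{Var}(\hat{N})

and the bounds are \hat{S} \pm z\,\sqrt{\operatorname{Var}(\hat{S})}, where z is the (1 - (1 - confidence)/2) quantile of the normal distribution when c > 100 and of a t-distribution with c - 1 degrees of freedom otherwise; the count == 1 and zero-count cases are special-cased so no NaN or zero-degrees-of-freedom distribution is produced.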
diff --git a/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala b/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala
index 8358244987..63d1d1767a 100644
--- a/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala
@@ -35,9 +35,9 @@ class BlockRDD[T: ClassTag](sc: SparkContext, @transient val blockIds: Array[Blo
override def getPartitions: Array[Partition] = {
assertValid()
- (0 until blockIds.length).map(i => {
+ (0 until blockIds.length).map { i =>
new BlockRDDPartition(blockIds(i), i).asInstanceOf[Partition]
- }).toArray
+ }.toArray
}
override def compute(split: Partition, context: TaskContext): Iterator[T] = {
diff --git a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala
index e5ebc63082..7bc1eb0436 100644
--- a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala
@@ -29,10 +29,12 @@ import org.apache.spark.serializer.Serializer
import org.apache.spark.util.collection.{CompactBuffer, ExternalAppendOnlyMap}
import org.apache.spark.util.Utils
-/** The references to rdd and splitIndex are transient because redundant information is stored
- * in the CoGroupedRDD object. Because CoGroupedRDD is serialized separately from
- * CoGroupPartition, if rdd and splitIndex aren't transient, they'll be included twice in the
- * task closure. */
+/**
+ * The references to rdd and splitIndex are transient because redundant information is stored
+ * in the CoGroupedRDD object. Because CoGroupedRDD is serialized separately from
+ * CoGroupPartition, if rdd and splitIndex aren't transient, they'll be included twice in the
+ * task closure.
+ */
private[spark] case class NarrowCoGroupSplitDep(
@transient rdd: RDD[_],
@transient splitIndex: Int,
diff --git a/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala
index 5e9230e733..368916a39e 100644
--- a/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala
@@ -166,8 +166,8 @@ class DoubleRDDFunctions(self: RDD[Double]) extends Logging with Serializable {
val counters = new Array[Long](buckets.length - 1)
while (iter.hasNext) {
bucketFunction(iter.next()) match {
- case Some(x: Int) => {counters(x) += 1}
- case _ => {}
+ case Some(x: Int) => counters(x) += 1
+ case _ => // No-Op
}
}
Iterator(counters)
diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
index 08db96edd6..35d190b464 100644
--- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
@@ -213,15 +213,13 @@ class HadoopRDD[K, V](
logInfo("Input split: " + split.inputSplit)
val jobConf = getJobConf()
- // TODO: there is a lot of duplicate code between this and NewHadoopRDD and SqlNewHadoopRDD
-
val inputMetrics = context.taskMetrics().registerInputMetrics(DataReadMethod.Hadoop)
val existingBytesRead = inputMetrics.bytesRead
// Sets the thread local variable for the file's name
split.inputSplit.value match {
- case fs: FileSplit => SqlNewHadoopRDDState.setInputFileName(fs.getPath.toString)
- case _ => SqlNewHadoopRDDState.unsetInputFileName()
+ case fs: FileSplit => InputFileNameHolder.setInputFileName(fs.getPath.toString)
+ case _ => InputFileNameHolder.unsetInputFileName()
}
// Find a function that will return the FileSystem bytes read by this thread. Do this before
@@ -261,7 +259,7 @@ class HadoopRDD[K, V](
finished = true
}
if (!finished) {
- inputMetrics.incRecordsReadInternal(1)
+ inputMetrics.incRecordsRead(1)
}
if (inputMetrics.recordsRead % SparkHadoopUtil.UPDATE_INPUT_METRICS_INTERVAL_RECORDS == 0) {
updateBytesRead()
@@ -271,7 +269,7 @@ class HadoopRDD[K, V](
override def close() {
if (reader != null) {
- SqlNewHadoopRDDState.unsetInputFileName()
+ InputFileNameHolder.unsetInputFileName()
// Close the reader and release it. Note: it's very important that we don't close the
// reader more than once, since that exposes us to MAPREDUCE-5918 when running against
// Hadoop 1.x and older Hadoop 2.x releases. That bug can lead to non-deterministic
@@ -293,7 +291,7 @@ class HadoopRDD[K, V](
// If we can't get the bytes read from the FS stats, fall back to the split size,
// which may be inaccurate.
try {
- inputMetrics.incBytesReadInternal(split.inputSplit.value.getLength)
+ inputMetrics.incBytesRead(split.inputSplit.value.getLength)
} catch {
case e: java.io.IOException =>
logWarning("Unable to get input size to set InputMetrics for task", e)
@@ -424,7 +422,7 @@ private[spark] object HadoopRDD extends Logging {
private[spark] def convertSplitLocationInfo(infos: Array[AnyRef]): Seq[String] = {
val out = ListBuffer[String]()
- infos.foreach { loc => {
+ infos.foreach { loc =>
val locationStr = HadoopRDD.SPLIT_INFO_REFLECTIONS.get.
getLocation.invoke(loc).asInstanceOf[String]
if (locationStr != "localhost") {
@@ -436,7 +434,7 @@ private[spark] object HadoopRDD extends Logging {
out += new HostTaskLocation(locationStr).toString
}
}
- }}
+ }
out.seq
}
}
diff --git a/core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDDState.scala b/core/src/main/scala/org/apache/spark/rdd/InputFileNameHolder.scala
index 3f15fff793..108e9d2558 100644
--- a/core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDDState.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/InputFileNameHolder.scala
@@ -20,10 +20,10 @@ package org.apache.spark.rdd
import org.apache.spark.unsafe.types.UTF8String
/**
- * State for SqlNewHadoopRDD objects. This is split this way because of the package splits.
- * TODO: Move/Combine this with org.apache.spark.sql.datasources.SqlNewHadoopRDD
+ * Holds the name of the file currently being read by a Spark task. This is used by HadoopRDD,
+ * FileScanRDD and the InputFileName function in Spark SQL.
*/
-private[spark] object SqlNewHadoopRDDState {
+private[spark] object InputFileNameHolder {
/**
* The thread variable for the name of the current file being read. This is used by
* the InputFileName function in Spark SQL.
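On the SQL side the holder backs the InputFileName expression; a typical use from the DataFrame API (illustrative, assuming an existing SQLContext named sqlContext, an input location path, and that functions.input_file_name is available in this build) would be:

    import org.apache.spark.sql.functions.input_file_name

    // Tag each row with the file it was read from.
    val withSource = sqlContext.read.text(path).withColumn("source_file", input_file_name())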
diff --git a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala
index fb9606ae38..3ccd616cbf 100644
--- a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala
@@ -189,7 +189,7 @@ class NewHadoopRDD[K, V](
}
havePair = false
if (!finished) {
- inputMetrics.incRecordsReadInternal(1)
+ inputMetrics.incRecordsRead(1)
}
if (inputMetrics.recordsRead % SparkHadoopUtil.UPDATE_INPUT_METRICS_INTERVAL_RECORDS == 0) {
updateBytesRead()
@@ -220,7 +220,7 @@ class NewHadoopRDD[K, V](
// If we can't get the bytes read from the FS stats, fall back to the split size,
// which may be inaccurate.
try {
- inputMetrics.incBytesReadInternal(split.serializableHadoopSplit.value.getLength)
+ inputMetrics.incBytesRead(split.serializableHadoopSplit.value.getLength)
} catch {
case e: java.io.IOException =>
logWarning("Unable to get input size to set InputMetrics for task", e)
diff --git a/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala
index 363004e587..a5992022d0 100644
--- a/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala
@@ -86,12 +86,11 @@ class OrderedRDDFunctions[K : Ordering : ClassTag,
def inRange(k: K): Boolean = ordering.gteq(k, lower) && ordering.lteq(k, upper)
val rddToFilter: RDD[P] = self.partitioner match {
- case Some(rp: RangePartitioner[K, V]) => {
+ case Some(rp: RangePartitioner[K, V]) =>
val partitionIndicies = (rp.getPartition(lower), rp.getPartition(upper)) match {
case (l, u) => Math.min(l, u) to Math.max(l, u)
}
PartitionPruningRDD.create(self, partitionIndicies.contains)
- }
case _ =>
self
}
diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
index 296179b75b..085829af6e 100644
--- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
@@ -1111,9 +1111,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
maybeUpdateOutputMetrics(outputMetricsAndBytesWrittenCallback, recordsWritten)
recordsWritten += 1
}
- } {
- writer.close(hadoopContext)
- }
+ }(finallyBlock = writer.close(hadoopContext))
committer.commitTask(hadoopContext)
outputMetricsAndBytesWrittenCallback.foreach { case (om, callback) =>
om.setBytesWritten(callback())
@@ -1200,9 +1198,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
maybeUpdateOutputMetrics(outputMetricsAndBytesWrittenCallback, recordsWritten)
recordsWritten += 1
}
- } {
- writer.close()
- }
+ }(finallyBlock = writer.close())
writer.commit()
outputMetricsAndBytesWrittenCallback.foreach { case (om, callback) =>
om.setBytesWritten(callback())
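Both call sites above now pass the cleanup as a named by-name argument instead of a trailing brace block. A hypothetical helper with the same calling shape (not the actual Spark utility, which also handles failure callbacks):

    object FinallySketch extends App {
      // `block` and `finallyBlock` are by-name, so a caller can write
      // tryWithFinally { ... }(finallyBlock = writer.close())
      def tryWithFinally[T](block: => T)(finallyBlock: => Unit = ()): T = {
        try block finally finallyBlock
      }

      val n = tryWithFinally {
        42
      }(finallyBlock = println("closing writer"))
      println(n) // 42
    }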
diff --git a/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala
index 582fa93afe..bb84e4af15 100644
--- a/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala
@@ -121,14 +121,14 @@ private object ParallelCollectionRDD {
// Sequences need to be sliced at the same set of index positions for operations
// like RDD.zip() to behave as expected
def positions(length: Long, numSlices: Int): Iterator[(Int, Int)] = {
- (0 until numSlices).iterator.map(i => {
+ (0 until numSlices).iterator.map { i =>
val start = ((i * length) / numSlices).toInt
val end = (((i + 1) * length) / numSlices).toInt
(start, end)
- })
+ }
}
seq match {
- case r: Range => {
+ case r: Range =>
positions(r.length, numSlices).zipWithIndex.map({ case ((start, end), index) =>
// If the range is inclusive, use inclusive range for the last slice
if (r.isInclusive && index == numSlices - 1) {
@@ -138,8 +138,7 @@ private object ParallelCollectionRDD {
new Range(r.start + start * r.step, r.start + end * r.step, r.step)
}
}).toSeq.asInstanceOf[Seq[Seq[T]]]
- }
- case nr: NumericRange[_] => {
+ case nr: NumericRange[_] =>
// For ranges of Long, Double, BigInteger, etc
val slices = new ArrayBuffer[Seq[T]](numSlices)
var r = nr
@@ -149,14 +148,12 @@ private object ParallelCollectionRDD {
r = r.drop(sliceSize)
}
slices
- }
- case _ => {
+ case _ =>
val array = seq.toArray // To prevent O(n^2) operations for List etc
positions(array.length, numSlices).map({
case (start, end) =>
array.slice(start, end).toSeq
}).toSeq
- }
}
}
}
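The positions helper rewritten above computes half-open (start, end) pairs, so sequences of the same length are always sliced at identical boundaries (which is what keeps RDD.zip aligned). A standalone worked example:

    object PositionsExample extends App {
      def positions(length: Long, numSlices: Int): Iterator[(Int, Int)] =
        (0 until numSlices).iterator.map { i =>
          val start = ((i * length) / numSlices).toInt
          val end = (((i + 1) * length) / numSlices).toInt
          (start, end)
        }

      println(positions(10, 3).toList)  // List((0,3), (3,6), (6,10))
    }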
diff --git a/core/src/main/scala/org/apache/spark/rdd/PartitionerAwareUnionRDD.scala b/core/src/main/scala/org/apache/spark/rdd/PartitionerAwareUnionRDD.scala
index 9e3880714a..0abba15bec 100644
--- a/core/src/main/scala/org/apache/spark/rdd/PartitionerAwareUnionRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/PartitionerAwareUnionRDD.scala
@@ -68,9 +68,9 @@ class PartitionerAwareUnionRDD[T: ClassTag](
override def getPartitions: Array[Partition] = {
val numPartitions = partitioner.get.numPartitions
- (0 until numPartitions).map(index => {
+ (0 until numPartitions).map { index =>
new PartitionerAwareUnionRDDPartition(rdds, index)
- }).toArray
+ }.toArray
}
// Get the location where most of the partitions of parent RDDs are located
@@ -78,11 +78,10 @@ class PartitionerAwareUnionRDD[T: ClassTag](
logDebug("Finding preferred location for " + this + ", partition " + s.index)
val parentPartitions = s.asInstanceOf[PartitionerAwareUnionRDDPartition].parents
val locations = rdds.zip(parentPartitions).flatMap {
- case (rdd, part) => {
+ case (rdd, part) =>
val parentLocations = currPrefLocs(rdd, part)
logDebug("Location of " + rdd + " partition " + part.index + " = " + parentLocations)
parentLocations
- }
}
val location = if (locations.isEmpty) {
None
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index f96551c793..36ff3bcaae 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -255,8 +255,8 @@ abstract class RDD[T: ClassTag](
}
/**
- * Returns the number of partitions of this RDD.
- */
+ * Returns the number of partitions of this RDD.
+ */
@Since("1.6.0")
final def getNumPartitions: Int = partitions.length
@@ -333,10 +333,10 @@ abstract class RDD[T: ClassTag](
case Left(blockResult) =>
if (readCachedBlock) {
val existingMetrics = context.taskMetrics().registerInputMetrics(blockResult.readMethod)
- existingMetrics.incBytesReadInternal(blockResult.bytes)
+ existingMetrics.incBytesRead(blockResult.bytes)
new InterruptibleIterator[T](context, blockResult.data.asInstanceOf[Iterator[T]]) {
override def next(): T = {
- existingMetrics.incRecordsReadInternal(1)
+ existingMetrics.incRecordsRead(1)
delegate.next()
}
}
@@ -568,11 +568,7 @@ abstract class RDD[T: ClassTag](
* times (use `.distinct()` to eliminate them).
*/
def union(other: RDD[T]): RDD[T] = withScope {
- if (partitioner.isDefined && other.partitioner == partitioner) {
- new PartitionerAwareUnionRDD(sc, Array(this, other))
- } else {
- new UnionRDD(sc, Array(this, other))
- }
+ sc.union(this, other)
}
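union now delegates to SparkContext.union, which in this version is expected to perform the equivalent partitioner check for any number of inputs. A hedged sketch of that decision, returning implementation names only since the partitioner-aware subclass is internal:

    import org.apache.spark.rdd.RDD

    object UnionChoiceSketch {
      // Sketch only: mirrors the check that used to live inline above.
      def chooseUnionImpl[T](rdds: Seq[RDD[T]]): String = {
        val partitioners = rdds.flatMap(_.partitioner).toSet
        if (rdds.nonEmpty && rdds.forall(_.partitioner.isDefined) && partitioners.size == 1) {
          "PartitionerAwareUnionRDD"   // keeps the shared partitioner
        } else {
          "UnionRDD"                   // plain concatenation of partitions
        }
      }
    }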
/**
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index 5cdc91316b..c27aad268d 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -950,13 +950,6 @@ class DAGScheduler(
// First figure out the indexes of partition ids to compute.
val partitionsToCompute: Seq[Int] = stage.findMissingPartitions()
- // Create internal accumulators if the stage has no accumulators initialized.
- // Reset internal accumulators only if this stage is not partially submitted
- // Otherwise, we may override existing accumulator values from some tasks
- if (stage.internalAccumulators.isEmpty || stage.numPartitions == partitionsToCompute.size) {
- stage.resetInternalAccumulators()
- }
-
// Use the scheduling pool, job group, description, etc. from an ActiveJob associated
// with this Stage
val properties = jobIdToActiveJob(jobId).properties
@@ -1036,7 +1029,7 @@ class DAGScheduler(
val locs = taskIdToLocations(id)
val part = stage.rdd.partitions(id)
new ShuffleMapTask(stage.id, stage.latestInfo.attemptId,
- taskBinary, part, locs, stage.internalAccumulators)
+ taskBinary, part, locs, stage.latestInfo.internalAccumulators, properties)
}
case stage: ResultStage =>
@@ -1046,7 +1039,7 @@ class DAGScheduler(
val part = stage.rdd.partitions(p)
val locs = taskIdToLocations(id)
new ResultTask(stage.id, stage.latestInfo.attemptId,
- taskBinary, part, locs, id, stage.internalAccumulators)
+ taskBinary, part, locs, id, properties, stage.latestInfo.internalAccumulators)
}
}
} catch {
diff --git a/core/src/main/scala/org/apache/spark/scheduler/InputFormatInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/InputFormatInfo.scala
index 0640f26051..a6b032cc00 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/InputFormatInfo.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/InputFormatInfo.scala
@@ -57,11 +57,10 @@ class InputFormatInfo(val configuration: Configuration, val inputFormatClazz: Cl
// Since we are not doing canonicalization of path, this can be wrong : like relative vs
// absolute path .. which is fine, this is best case effort to remove duplicates - right ?
override def equals(other: Any): Boolean = other match {
- case that: InputFormatInfo => {
+ case that: InputFormatInfo =>
// not checking config - that should be fine, right ?
this.inputFormatClazz == that.inputFormatClazz &&
this.path == that.path
- }
case _ => false
}
@@ -86,10 +85,9 @@ class InputFormatInfo(val configuration: Configuration, val inputFormatClazz: Cl
}
}
catch {
- case e: ClassNotFoundException => {
+ case e: ClassNotFoundException =>
throw new IllegalArgumentException("Specified inputformat " + inputFormatClazz +
" cannot be found ?", e)
- }
}
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/JobListener.scala b/core/src/main/scala/org/apache/spark/scheduler/JobListener.scala
index 50c2b9acd6..e0f7c8f021 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/JobListener.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/JobListener.scala
@@ -23,6 +23,6 @@ package org.apache.spark.scheduler
* job fails (and no further taskSucceeded events will happen).
*/
private[spark] trait JobListener {
- def taskSucceeded(index: Int, result: Any)
- def jobFailed(exception: Exception)
+ def taskSucceeded(index: Int, result: Any): Unit
+ def jobFailed(exception: Exception): Unit
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala
index cd2736e196..db6276f75d 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala
@@ -19,6 +19,7 @@ package org.apache.spark.scheduler
import java.io._
import java.nio.ByteBuffer
+import java.util.Properties
import org.apache.spark._
import org.apache.spark.broadcast.Broadcast
@@ -38,6 +39,7 @@ import org.apache.spark.rdd.RDD
* @param locs preferred task execution locations for locality scheduling
* @param outputId index of the task in this job (a job can launch tasks on only a subset of the
* input RDD's partitions).
+ * @param localProperties copy of thread-local properties set by the user on the driver side.
* @param _initialAccums initial set of accumulators to be used in this task for tracking
* internal metrics. Other accumulators will be registered later when
* they are deserialized on the executors.
@@ -49,8 +51,9 @@ private[spark] class ResultTask[T, U](
partition: Partition,
locs: Seq[TaskLocation],
val outputId: Int,
+ localProperties: Properties,
_initialAccums: Seq[Accumulator[_]] = InternalAccumulator.createAll())
- extends Task[U](stageId, stageAttemptId, partition.index, _initialAccums)
+ extends Task[U](stageId, stageAttemptId, partition.index, _initialAccums, localProperties)
with Serializable {
@transient private[this] val preferredLocs: Seq[TaskLocation] = {
diff --git a/core/src/main/scala/org/apache/spark/scheduler/SchedulableBuilder.scala b/core/src/main/scala/org/apache/spark/scheduler/SchedulableBuilder.scala
index 5baebe8c1f..100ed76ecb 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/SchedulableBuilder.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/SchedulableBuilder.scala
@@ -34,9 +34,9 @@ import org.apache.spark.util.Utils
private[spark] trait SchedulableBuilder {
def rootPool: Pool
- def buildPools()
+ def buildPools(): Unit
- def addTaskSetManager(manager: Schedulable, properties: Properties)
+ def addTaskSetManager(manager: Schedulable, properties: Properties): Unit
}
private[spark] class FIFOSchedulableBuilder(val rootPool: Pool)
diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
index e30964a01b..b7cab7013e 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
@@ -18,6 +18,7 @@
package org.apache.spark.scheduler
import java.nio.ByteBuffer
+import java.util.Properties
import scala.language.existentials
@@ -42,6 +43,7 @@ import org.apache.spark.shuffle.ShuffleWriter
* @param _initialAccums initial set of accumulators to be used in this task for tracking
* internal metrics. Other accumulators will be registered later when
* they are deserialized on the executors.
+ * @param localProperties copy of thread-local properties set by the user on the driver side.
*/
private[spark] class ShuffleMapTask(
stageId: Int,
@@ -49,13 +51,14 @@ private[spark] class ShuffleMapTask(
taskBinary: Broadcast[Array[Byte]],
partition: Partition,
@transient private var locs: Seq[TaskLocation],
- _initialAccums: Seq[Accumulator[_]])
- extends Task[MapStatus](stageId, stageAttemptId, partition.index, _initialAccums)
+ _initialAccums: Seq[Accumulator[_]],
+ localProperties: Properties)
+ extends Task[MapStatus](stageId, stageAttemptId, partition.index, _initialAccums, localProperties)
with Logging {
/** A constructor used only in test suites. This does not require passing in an RDD. */
def this(partitionId: Int) {
- this(0, 0, null, new Partition { override def index: Int = 0 }, null, null)
+ this(0, 0, null, new Partition { override def index: Int = 0 }, null, null, new Properties)
}
@transient private val preferredLocs: Seq[TaskLocation] = {
diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala
index 586173f180..080ea6c33a 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala
@@ -151,275 +151,152 @@ private[spark] trait SparkHistoryListenerFactory {
def createListeners(conf: SparkConf, sparkUI: SparkUI): Seq[SparkListener]
}
+
/**
- * :: DeveloperApi ::
- * Interface for listening to events from the Spark scheduler. Note that this is an internal
- * interface which might change in different Spark releases. Java clients should extend
- * {@link JavaSparkListener}
+ * Interface for listening to events from the Spark scheduler. Most applications should probably
+ * extend SparkListener or SparkFirehoseListener directly, rather than implementing this class.
+ *
+ * Note that this is an internal interface which might change in different Spark releases.
*/
-@DeveloperApi
-trait SparkListener {
+private[spark] trait SparkListenerInterface {
+
/**
* Called when a stage completes successfully or fails, with information on the completed stage.
*/
- def onStageCompleted(stageCompleted: SparkListenerStageCompleted) { }
+ def onStageCompleted(stageCompleted: SparkListenerStageCompleted): Unit
/**
* Called when a stage is submitted
*/
- def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted) { }
+ def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted): Unit
/**
* Called when a task starts
*/
- def onTaskStart(taskStart: SparkListenerTaskStart) { }
+ def onTaskStart(taskStart: SparkListenerTaskStart): Unit
/**
* Called when a task begins remotely fetching its result (will not be called for tasks that do
* not need to fetch the result remotely).
*/
- def onTaskGettingResult(taskGettingResult: SparkListenerTaskGettingResult) { }
+ def onTaskGettingResult(taskGettingResult: SparkListenerTaskGettingResult): Unit
/**
* Called when a task ends
*/
- def onTaskEnd(taskEnd: SparkListenerTaskEnd) { }
+ def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit
/**
* Called when a job starts
*/
- def onJobStart(jobStart: SparkListenerJobStart) { }
+ def onJobStart(jobStart: SparkListenerJobStart): Unit
/**
* Called when a job ends
*/
- def onJobEnd(jobEnd: SparkListenerJobEnd) { }
+ def onJobEnd(jobEnd: SparkListenerJobEnd): Unit
/**
* Called when environment properties have been updated
*/
- def onEnvironmentUpdate(environmentUpdate: SparkListenerEnvironmentUpdate) { }
+ def onEnvironmentUpdate(environmentUpdate: SparkListenerEnvironmentUpdate): Unit
/**
* Called when a new block manager has joined
*/
- def onBlockManagerAdded(blockManagerAdded: SparkListenerBlockManagerAdded) { }
+ def onBlockManagerAdded(blockManagerAdded: SparkListenerBlockManagerAdded): Unit
/**
* Called when an existing block manager has been removed
*/
- def onBlockManagerRemoved(blockManagerRemoved: SparkListenerBlockManagerRemoved) { }
+ def onBlockManagerRemoved(blockManagerRemoved: SparkListenerBlockManagerRemoved): Unit
/**
* Called when an RDD is manually unpersisted by the application
*/
- def onUnpersistRDD(unpersistRDD: SparkListenerUnpersistRDD) { }
+ def onUnpersistRDD(unpersistRDD: SparkListenerUnpersistRDD): Unit
/**
* Called when the application starts
*/
- def onApplicationStart(applicationStart: SparkListenerApplicationStart) { }
+ def onApplicationStart(applicationStart: SparkListenerApplicationStart): Unit
/**
* Called when the application ends
*/
- def onApplicationEnd(applicationEnd: SparkListenerApplicationEnd) { }
+ def onApplicationEnd(applicationEnd: SparkListenerApplicationEnd): Unit
/**
* Called when the driver receives task metrics from an executor in a heartbeat.
*/
- def onExecutorMetricsUpdate(executorMetricsUpdate: SparkListenerExecutorMetricsUpdate) { }
+ def onExecutorMetricsUpdate(executorMetricsUpdate: SparkListenerExecutorMetricsUpdate): Unit
/**
* Called when the driver registers a new executor.
*/
- def onExecutorAdded(executorAdded: SparkListenerExecutorAdded) { }
+ def onExecutorAdded(executorAdded: SparkListenerExecutorAdded): Unit
/**
* Called when the driver removes an executor.
*/
- def onExecutorRemoved(executorRemoved: SparkListenerExecutorRemoved) { }
+ def onExecutorRemoved(executorRemoved: SparkListenerExecutorRemoved): Unit
/**
* Called when the driver receives a block update info.
*/
- def onBlockUpdated(blockUpdated: SparkListenerBlockUpdated) { }
+ def onBlockUpdated(blockUpdated: SparkListenerBlockUpdated): Unit
/**
* Called when other events like SQL-specific events are posted.
*/
- def onOtherEvent(event: SparkListenerEvent) { }
+ def onOtherEvent(event: SparkListenerEvent): Unit
}
+
/**
* :: DeveloperApi ::
- * Simple SparkListener that logs a few summary statistics when each stage completes
+ * A default implementation for [[SparkListenerInterface]] that has no-op implementations for
+ * all callbacks.
+ *
+ * Note that this is an internal interface which might change in different Spark releases.
*/
@DeveloperApi
-class StatsReportListener extends SparkListener with Logging {
-
- import org.apache.spark.scheduler.StatsReportListener._
-
- private val taskInfoMetrics = mutable.Buffer[(TaskInfo, TaskMetrics)]()
-
- override def onTaskEnd(taskEnd: SparkListenerTaskEnd) {
- val info = taskEnd.taskInfo
- val metrics = taskEnd.taskMetrics
- if (info != null && metrics != null) {
- taskInfoMetrics += ((info, metrics))
- }
- }
-
- override def onStageCompleted(stageCompleted: SparkListenerStageCompleted) {
- implicit val sc = stageCompleted
- this.logInfo(s"Finished stage: ${getStatusDetail(stageCompleted.stageInfo)}")
- showMillisDistribution("task runtime:", (info, _) => Some(info.duration), taskInfoMetrics)
-
- // Shuffle write
- showBytesDistribution("shuffle bytes written:",
- (_, metric) => metric.shuffleWriteMetrics.map(_.bytesWritten), taskInfoMetrics)
-
- // Fetch & I/O
- showMillisDistribution("fetch wait time:",
- (_, metric) => metric.shuffleReadMetrics.map(_.fetchWaitTime), taskInfoMetrics)
- showBytesDistribution("remote bytes read:",
- (_, metric) => metric.shuffleReadMetrics.map(_.remoteBytesRead), taskInfoMetrics)
- showBytesDistribution("task result size:",
- (_, metric) => Some(metric.resultSize), taskInfoMetrics)
-
- // Runtime breakdown
- val runtimePcts = taskInfoMetrics.map { case (info, metrics) =>
- RuntimePercentage(info.duration, metrics)
- }
- showDistribution("executor (non-fetch) time pct: ",
- Distribution(runtimePcts.map(_.executorPct * 100)), "%2.0f %%")
- showDistribution("fetch wait time pct: ",
- Distribution(runtimePcts.flatMap(_.fetchPct.map(_ * 100))), "%2.0f %%")
- showDistribution("other time pct: ", Distribution(runtimePcts.map(_.other * 100)), "%2.0f %%")
- taskInfoMetrics.clear()
- }
-
- private def getStatusDetail(info: StageInfo): String = {
- val failureReason = info.failureReason.map("(" + _ + ")").getOrElse("")
- val timeTaken = info.submissionTime.map(
- x => info.completionTime.getOrElse(System.currentTimeMillis()) - x
- ).getOrElse("-")
-
- s"Stage(${info.stageId}, ${info.attemptId}); Name: '${info.name}'; " +
- s"Status: ${info.getStatusString}$failureReason; numTasks: ${info.numTasks}; " +
- s"Took: $timeTaken msec"
- }
+abstract class SparkListener extends SparkListenerInterface {
+ override def onStageCompleted(stageCompleted: SparkListenerStageCompleted): Unit = { }
-}
+ override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted): Unit = { }
-private[spark] object StatsReportListener extends Logging {
-
- // For profiling, the extremes are more interesting
- val percentiles = Array[Int](0, 5, 10, 25, 50, 75, 90, 95, 100)
- val probabilities = percentiles.map(_ / 100.0)
- val percentilesHeader = "\t" + percentiles.mkString("%\t") + "%"
-
- def extractDoubleDistribution(
- taskInfoMetrics: Seq[(TaskInfo, TaskMetrics)],
- getMetric: (TaskInfo, TaskMetrics) => Option[Double]): Option[Distribution] = {
- Distribution(taskInfoMetrics.flatMap { case (info, metric) => getMetric(info, metric) })
- }
-
- // Is there some way to setup the types that I can get rid of this completely?
- def extractLongDistribution(
- taskInfoMetrics: Seq[(TaskInfo, TaskMetrics)],
- getMetric: (TaskInfo, TaskMetrics) => Option[Long]): Option[Distribution] = {
- extractDoubleDistribution(
- taskInfoMetrics,
- (info, metric) => { getMetric(info, metric).map(_.toDouble) })
- }
-
- def showDistribution(heading: String, d: Distribution, formatNumber: Double => String) {
- val stats = d.statCounter
- val quantiles = d.getQuantiles(probabilities).map(formatNumber)
- logInfo(heading + stats)
- logInfo(percentilesHeader)
- logInfo("\t" + quantiles.mkString("\t"))
- }
-
- def showDistribution(
- heading: String,
- dOpt: Option[Distribution],
- formatNumber: Double => String) {
- dOpt.foreach { d => showDistribution(heading, d, formatNumber)}
- }
-
- def showDistribution(heading: String, dOpt: Option[Distribution], format: String) {
- def f(d: Double): String = format.format(d)
- showDistribution(heading, dOpt, f _)
- }
-
- def showDistribution(
- heading: String,
- format: String,
- getMetric: (TaskInfo, TaskMetrics) => Option[Double],
- taskInfoMetrics: Seq[(TaskInfo, TaskMetrics)]) {
- showDistribution(heading, extractDoubleDistribution(taskInfoMetrics, getMetric), format)
- }
-
- def showBytesDistribution(
- heading: String,
- getMetric: (TaskInfo, TaskMetrics) => Option[Long],
- taskInfoMetrics: Seq[(TaskInfo, TaskMetrics)]) {
- showBytesDistribution(heading, extractLongDistribution(taskInfoMetrics, getMetric))
- }
-
- def showBytesDistribution(heading: String, dOpt: Option[Distribution]) {
- dOpt.foreach { dist => showBytesDistribution(heading, dist) }
- }
-
- def showBytesDistribution(heading: String, dist: Distribution) {
- showDistribution(heading, dist, (d => Utils.bytesToString(d.toLong)): Double => String)
- }
-
- def showMillisDistribution(heading: String, dOpt: Option[Distribution]) {
- showDistribution(heading, dOpt,
- (d => StatsReportListener.millisToString(d.toLong)): Double => String)
- }
-
- def showMillisDistribution(
- heading: String,
- getMetric: (TaskInfo, TaskMetrics) => Option[Long],
- taskInfoMetrics: Seq[(TaskInfo, TaskMetrics)]) {
- showMillisDistribution(heading, extractLongDistribution(taskInfoMetrics, getMetric))
- }
-
- val seconds = 1000L
- val minutes = seconds * 60
- val hours = minutes * 60
+ override def onTaskStart(taskStart: SparkListenerTaskStart): Unit = { }
- /**
- * Reformat a time interval in milliseconds to a prettier format for output
- */
- def millisToString(ms: Long): String = {
- val (size, units) =
- if (ms > hours) {
- (ms.toDouble / hours, "hours")
- } else if (ms > minutes) {
- (ms.toDouble / minutes, "min")
- } else if (ms > seconds) {
- (ms.toDouble / seconds, "s")
- } else {
- (ms.toDouble, "ms")
- }
- "%.1f %s".format(size, units)
- }
-}
+ override def onTaskGettingResult(taskGettingResult: SparkListenerTaskGettingResult): Unit = { }
+
+ override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = { }
+
+ override def onJobStart(jobStart: SparkListenerJobStart): Unit = { }
+
+ override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = { }
+
+ override def onEnvironmentUpdate(environmentUpdate: SparkListenerEnvironmentUpdate): Unit = { }
+
+ override def onBlockManagerAdded(blockManagerAdded: SparkListenerBlockManagerAdded): Unit = { }
+
+ override def onBlockManagerRemoved(
+ blockManagerRemoved: SparkListenerBlockManagerRemoved): Unit = { }
+
+ override def onUnpersistRDD(unpersistRDD: SparkListenerUnpersistRDD): Unit = { }
+
+ override def onApplicationStart(applicationStart: SparkListenerApplicationStart): Unit = { }
+
+ override def onApplicationEnd(applicationEnd: SparkListenerApplicationEnd): Unit = { }
+
+ override def onExecutorMetricsUpdate(
+ executorMetricsUpdate: SparkListenerExecutorMetricsUpdate): Unit = { }
+
+ override def onExecutorAdded(executorAdded: SparkListenerExecutorAdded): Unit = { }
+
+ override def onExecutorRemoved(executorRemoved: SparkListenerExecutorRemoved): Unit = { }
+
+ override def onBlockUpdated(blockUpdated: SparkListenerBlockUpdated): Unit = { }
-private case class RuntimePercentage(executorPct: Double, fetchPct: Option[Double], other: Double)
-
-private object RuntimePercentage {
- def apply(totalTime: Long, metrics: TaskMetrics): RuntimePercentage = {
- val denom = totalTime.toDouble
- val fetchTime = metrics.shuffleReadMetrics.map(_.fetchWaitTime)
- val fetch = fetchTime.map(_ / denom)
- val exec = (metrics.executorRunTime - fetchTime.getOrElse(0L)) / denom
- val other = 1.0 - (exec + fetch.getOrElse(0d))
- RuntimePercentage(exec, fetch, other)
- }
+ override def onOtherEvent(event: SparkListenerEvent): Unit = { }
}
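A hedged example of the intended division of labour after this refactor: user code extends the no-op SparkListener base class and overrides only the callbacks it needs, while SparkListenerInterface stays internal (class and field names below are illustrative):

    import org.apache.spark.scheduler.{SparkListener, SparkListenerJobEnd, SparkListenerTaskEnd}

    class TaskCountingListener extends SparkListener {
      @volatile private var tasksEnded = 0L

      override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = { tasksEnded += 1 }

      override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = {
        println(s"job ${jobEnd.jobId} saw $tasksEnded task-end events so far")
      }
    }

    // Registration is unchanged, e.g. sc.addSparkListener(new TaskCountingListener)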
diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala
index 94f0574f0e..471586ac08 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala
@@ -22,9 +22,12 @@ import org.apache.spark.util.ListenerBus
/**
* A [[SparkListenerEvent]] bus that relays [[SparkListenerEvent]]s to its listeners
*/
-private[spark] trait SparkListenerBus extends ListenerBus[SparkListener, SparkListenerEvent] {
+private[spark] trait SparkListenerBus
+ extends ListenerBus[SparkListenerInterface, SparkListenerEvent] {
- protected override def doPostEvent(listener: SparkListener, event: SparkListenerEvent): Unit = {
+ protected override def doPostEvent(
+ listener: SparkListenerInterface,
+ event: SparkListenerEvent): Unit = {
event match {
case stageSubmitted: SparkListenerStageSubmitted =>
listener.onStageSubmitted(stageSubmitted)
diff --git a/core/src/main/scala/org/apache/spark/scheduler/SplitInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/SplitInfo.scala
index 6e9337bb90..bc1431835e 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/SplitInfo.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/SplitInfo.scala
@@ -49,14 +49,13 @@ class SplitInfo(
// So unless there is identity equality between underlyingSplits, it will always fail even if it
// is pointing to same block.
override def equals(other: Any): Boolean = other match {
- case that: SplitInfo => {
+ case that: SplitInfo =>
this.hostLocation == that.hostLocation &&
this.inputFormatClazz == that.inputFormatClazz &&
this.path == that.path &&
this.length == that.length &&
// other split specific checks (like start for FileSplit)
this.underlyingSplit == that.underlyingSplit
- }
case _ => false
}
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala
index a40b700cdd..b6d4e39fe5 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala
@@ -75,22 +75,6 @@ private[scheduler] abstract class Stage(
val name: String = callSite.shortForm
val details: String = callSite.longForm
- private var _internalAccumulators: Seq[Accumulator[_]] = Seq.empty
-
- /** Internal accumulators shared across all tasks in this stage. */
- def internalAccumulators: Seq[Accumulator[_]] = _internalAccumulators
-
- /**
- * Re-initialize the internal accumulators associated with this stage.
- *
- * This is called every time the stage is submitted, *except* when a subset of tasks
- * belonging to this stage has already finished. Otherwise, reinitializing the internal
- * accumulators here again will override partial values from the finished tasks.
- */
- def resetInternalAccumulators(): Unit = {
- _internalAccumulators = InternalAccumulator.create(rdd.sparkContext)
- }
-
/**
* Pointer to the [StageInfo] object for the most recent attempt. This needs to be initialized
* here, before any attempts have actually been created, because the DAGScheduler uses this
@@ -127,7 +111,8 @@ private[scheduler] abstract class Stage(
numPartitionsToCompute: Int,
taskLocalityPreferences: Seq[Seq[TaskLocation]] = Seq.empty): Unit = {
_latestInfo = StageInfo.fromStage(
- this, nextAttemptId, Some(numPartitionsToCompute), taskLocalityPreferences)
+ this, nextAttemptId, Some(numPartitionsToCompute),
+ InternalAccumulator.createAll(rdd.sparkContext), taskLocalityPreferences)
nextAttemptId += 1
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala
index 24796c1430..0fd58c41cd 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala
@@ -19,6 +19,7 @@ package org.apache.spark.scheduler
import scala.collection.mutable.HashMap
+import org.apache.spark.Accumulator
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.storage.RDDInfo
@@ -35,6 +36,7 @@ class StageInfo(
val rddInfos: Seq[RDDInfo],
val parentIds: Seq[Int],
val details: String,
+ val internalAccumulators: Seq[Accumulator[_]] = Seq.empty,
private[spark] val taskLocalityPreferences: Seq[Seq[TaskLocation]] = Seq.empty) {
/** When this stage was submitted from the DAGScheduler to a TaskScheduler. */
var submissionTime: Option[Long] = None
@@ -42,7 +44,11 @@ class StageInfo(
var completionTime: Option[Long] = None
/** If the stage failed, the reason why. */
var failureReason: Option[String] = None
- /** Terminal values of accumulables updated during this stage. */
+
+ /**
+ * Terminal values of accumulables updated during this stage, including all the user-defined
+ * accumulators.
+ */
val accumulables = HashMap[Long, AccumulableInfo]()
def stageFailed(reason: String) {
@@ -75,6 +81,7 @@ private[spark] object StageInfo {
stage: Stage,
attemptId: Int,
numTasks: Option[Int] = None,
+ internalAccumulators: Seq[Accumulator[_]] = Seq.empty,
taskLocalityPreferences: Seq[Seq[TaskLocation]] = Seq.empty
): StageInfo = {
val ancestorRddInfos = stage.rdd.getNarrowAncestors.map(RDDInfo.fromRdd)
@@ -87,6 +94,7 @@ private[spark] object StageInfo {
rddInfos,
stage.parents.map(_.id),
stage.details,
+ internalAccumulators,
taskLocalityPreferences)
}
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/StatsReportListener.scala b/core/src/main/scala/org/apache/spark/scheduler/StatsReportListener.scala
new file mode 100644
index 0000000000..309f4b806b
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/scheduler/StatsReportListener.scala
@@ -0,0 +1,199 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler
+
+import scala.collection.mutable
+
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.executor.TaskMetrics
+import org.apache.spark.internal.Logging
+import org.apache.spark.util.{Distribution, Utils}
+
+
+/**
+ * :: DeveloperApi ::
+ * Simple SparkListener that logs a few summary statistics when each stage completes.
+ */
+@DeveloperApi
+class StatsReportListener extends SparkListener with Logging {
+
+ import org.apache.spark.scheduler.StatsReportListener._
+
+ private val taskInfoMetrics = mutable.Buffer[(TaskInfo, TaskMetrics)]()
+
+ override def onTaskEnd(taskEnd: SparkListenerTaskEnd) {
+ val info = taskEnd.taskInfo
+ val metrics = taskEnd.taskMetrics
+ if (info != null && metrics != null) {
+ taskInfoMetrics += ((info, metrics))
+ }
+ }
+
+ override def onStageCompleted(stageCompleted: SparkListenerStageCompleted) {
+ implicit val sc = stageCompleted
+ this.logInfo(s"Finished stage: ${getStatusDetail(stageCompleted.stageInfo)}")
+ showMillisDistribution("task runtime:", (info, _) => Some(info.duration), taskInfoMetrics)
+
+ // Shuffle write
+ showBytesDistribution("shuffle bytes written:",
+ (_, metric) => metric.shuffleWriteMetrics.map(_.bytesWritten), taskInfoMetrics)
+
+ // Fetch & I/O
+ showMillisDistribution("fetch wait time:",
+ (_, metric) => metric.shuffleReadMetrics.map(_.fetchWaitTime), taskInfoMetrics)
+ showBytesDistribution("remote bytes read:",
+ (_, metric) => metric.shuffleReadMetrics.map(_.remoteBytesRead), taskInfoMetrics)
+ showBytesDistribution("task result size:",
+ (_, metric) => Some(metric.resultSize), taskInfoMetrics)
+
+ // Runtime breakdown
+ val runtimePcts = taskInfoMetrics.map { case (info, metrics) =>
+ RuntimePercentage(info.duration, metrics)
+ }
+ showDistribution("executor (non-fetch) time pct: ",
+ Distribution(runtimePcts.map(_.executorPct * 100)), "%2.0f %%")
+ showDistribution("fetch wait time pct: ",
+ Distribution(runtimePcts.flatMap(_.fetchPct.map(_ * 100))), "%2.0f %%")
+ showDistribution("other time pct: ", Distribution(runtimePcts.map(_.other * 100)), "%2.0f %%")
+ taskInfoMetrics.clear()
+ }
+
+ private def getStatusDetail(info: StageInfo): String = {
+ val failureReason = info.failureReason.map("(" + _ + ")").getOrElse("")
+ val timeTaken = info.submissionTime.map(
+ x => info.completionTime.getOrElse(System.currentTimeMillis()) - x
+ ).getOrElse("-")
+
+ s"Stage(${info.stageId}, ${info.attemptId}); Name: '${info.name}'; " +
+ s"Status: ${info.getStatusString}$failureReason; numTasks: ${info.numTasks}; " +
+ s"Took: $timeTaken msec"
+ }
+
+}
+
+private[spark] object StatsReportListener extends Logging {
+
+ // For profiling, the extremes are more interesting
+ val percentiles = Array[Int](0, 5, 10, 25, 50, 75, 90, 95, 100)
+ val probabilities = percentiles.map(_ / 100.0)
+ val percentilesHeader = "\t" + percentiles.mkString("%\t") + "%"
+
+ def extractDoubleDistribution(
+ taskInfoMetrics: Seq[(TaskInfo, TaskMetrics)],
+ getMetric: (TaskInfo, TaskMetrics) => Option[Double]): Option[Distribution] = {
+ Distribution(taskInfoMetrics.flatMap { case (info, metric) => getMetric(info, metric) })
+ }
+
+ // Is there some way to setup the types that I can get rid of this completely?
+ def extractLongDistribution(
+ taskInfoMetrics: Seq[(TaskInfo, TaskMetrics)],
+ getMetric: (TaskInfo, TaskMetrics) => Option[Long]): Option[Distribution] = {
+ extractDoubleDistribution(
+ taskInfoMetrics,
+ (info, metric) => { getMetric(info, metric).map(_.toDouble) })
+ }
+
+ def showDistribution(heading: String, d: Distribution, formatNumber: Double => String) {
+ val stats = d.statCounter
+ val quantiles = d.getQuantiles(probabilities).map(formatNumber)
+ logInfo(heading + stats)
+ logInfo(percentilesHeader)
+ logInfo("\t" + quantiles.mkString("\t"))
+ }
+
+ def showDistribution(
+ heading: String,
+ dOpt: Option[Distribution],
+ formatNumber: Double => String) {
+ dOpt.foreach { d => showDistribution(heading, d, formatNumber)}
+ }
+
+ def showDistribution(heading: String, dOpt: Option[Distribution], format: String) {
+ def f(d: Double): String = format.format(d)
+ showDistribution(heading, dOpt, f _)
+ }
+
+ def showDistribution(
+ heading: String,
+ format: String,
+ getMetric: (TaskInfo, TaskMetrics) => Option[Double],
+ taskInfoMetrics: Seq[(TaskInfo, TaskMetrics)]) {
+ showDistribution(heading, extractDoubleDistribution(taskInfoMetrics, getMetric), format)
+ }
+
+ def showBytesDistribution(
+ heading: String,
+ getMetric: (TaskInfo, TaskMetrics) => Option[Long],
+ taskInfoMetrics: Seq[(TaskInfo, TaskMetrics)]) {
+ showBytesDistribution(heading, extractLongDistribution(taskInfoMetrics, getMetric))
+ }
+
+ def showBytesDistribution(heading: String, dOpt: Option[Distribution]) {
+ dOpt.foreach { dist => showBytesDistribution(heading, dist) }
+ }
+
+ def showBytesDistribution(heading: String, dist: Distribution) {
+ showDistribution(heading, dist, (d => Utils.bytesToString(d.toLong)): Double => String)
+ }
+
+ def showMillisDistribution(heading: String, dOpt: Option[Distribution]) {
+ showDistribution(heading, dOpt,
+ (d => StatsReportListener.millisToString(d.toLong)): Double => String)
+ }
+
+ def showMillisDistribution(
+ heading: String,
+ getMetric: (TaskInfo, TaskMetrics) => Option[Long],
+ taskInfoMetrics: Seq[(TaskInfo, TaskMetrics)]) {
+ showMillisDistribution(heading, extractLongDistribution(taskInfoMetrics, getMetric))
+ }
+
+ val seconds = 1000L
+ val minutes = seconds * 60
+ val hours = minutes * 60
+
+ /**
+ * Reformat a time interval in milliseconds to a prettier format for output
+ */
+ def millisToString(ms: Long): String = {
+ val (size, units) =
+ if (ms > hours) {
+ (ms.toDouble / hours, "hours")
+ } else if (ms > minutes) {
+ (ms.toDouble / minutes, "min")
+ } else if (ms > seconds) {
+ (ms.toDouble / seconds, "s")
+ } else {
+ (ms.toDouble, "ms")
+ }
+ "%.1f %s".format(size, units)
+ }
+}
+
+private case class RuntimePercentage(executorPct: Double, fetchPct: Option[Double], other: Double)
+
+private object RuntimePercentage {
+ def apply(totalTime: Long, metrics: TaskMetrics): RuntimePercentage = {
+ val denom = totalTime.toDouble
+ val fetchTime = metrics.shuffleReadMetrics.map(_.fetchWaitTime)
+ val fetch = fetchTime.map(_ / denom)
+ val exec = (metrics.executorRunTime - fetchTime.getOrElse(0L)) / denom
+ val other = 1.0 - (exec + fetch.getOrElse(0d))
+ RuntimePercentage(exec, fetch, other)
+ }
+}
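The listener itself is unchanged by the move to its own file; it can still be attached through configuration, which instantiates it with its zero-argument constructor at startup. A minimal sketch:

    import org.apache.spark.SparkConf

    object StatsReportConfExample extends App {
      val conf = new SparkConf()
        .setAppName("stats-report-example")
        .set("spark.extraListeners", "org.apache.spark.scheduler.StatsReportListener")
      println(conf.get("spark.extraListeners"))
      // Each completed stage is then summarised in the driver logs at the
      // percentiles defined above (0, 5, 10, 25, 50, 75, 90, 95, 100).
    }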
diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
index d2b8ca90a9..1ff9d7795f 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
@@ -19,12 +19,13 @@ package org.apache.spark.scheduler
import java.io.{DataInputStream, DataOutputStream}
import java.nio.ByteBuffer
+import java.util.Properties
import scala.collection.mutable.HashMap
import org.apache.spark.{Accumulator, SparkEnv, TaskContext, TaskContextImpl}
import org.apache.spark.executor.TaskMetrics
-import org.apache.spark.memory.TaskMemoryManager
+import org.apache.spark.memory.{MemoryMode, TaskMemoryManager}
import org.apache.spark.metrics.MetricsSystem
import org.apache.spark.serializer.SerializerInstance
import org.apache.spark.util.{ByteBufferInputStream, ByteBufferOutputStream, Utils}
@@ -46,12 +47,14 @@ import org.apache.spark.util.{ByteBufferInputStream, ByteBufferOutputStream, Uti
* @param initialAccumulators initial set of accumulators to be used in this task for tracking
* internal metrics. Other accumulators will be registered later when
* they are deserialized on the executors.
+ * @param localProperties copy of thread-local properties set by the user on the driver side.
*/
private[spark] abstract class Task[T](
val stageId: Int,
val stageAttemptId: Int,
val partitionId: Int,
- val initialAccumulators: Seq[Accumulator[_]]) extends Serializable {
+ val initialAccumulators: Seq[Accumulator[_]],
+ @transient var localProperties: Properties) extends Serializable {
/**
* Called by [[org.apache.spark.executor.Executor]] to run this task.
@@ -71,6 +74,7 @@ private[spark] abstract class Task[T](
taskAttemptId,
attemptNumber,
taskMemoryManager,
+ localProperties,
metricsSystem,
initialAccumulators)
TaskContext.setTaskContext(context)
@@ -80,17 +84,24 @@ private[spark] abstract class Task[T](
}
try {
runTask(context)
- } catch { case e: Throwable =>
- // Catch all errors; run task failure callbacks, and rethrow the exception.
- context.markTaskFailed(e)
- throw e
+ } catch {
+ case e: Throwable =>
+ // Catch all errors; run task failure callbacks, and rethrow the exception.
+ try {
+ context.markTaskFailed(e)
+ } catch {
+ case t: Throwable =>
+ e.addSuppressed(t)
+ }
+ throw e
} finally {
// Call the task completion callbacks.
context.markTaskCompleted()
try {
Utils.tryLogNonFatalError {
// Release memory used by this thread for unrolling blocks
- SparkEnv.get.blockManager.memoryStore.releaseUnrollMemoryForThisTask()
+ SparkEnv.get.blockManager.memoryStore.releaseUnrollMemoryForThisTask(MemoryMode.ON_HEAP)
+ SparkEnv.get.blockManager.memoryStore.releaseUnrollMemoryForThisTask(MemoryMode.OFF_HEAP)
// Notify any tasks waiting for execution memory to be freed to wake up and try to
// acquire memory again. This makes impossible the scenario where a task sleeps forever
// because there are no other tasks left to notify it. Since this is safe to do but may
@@ -205,6 +216,11 @@ private[spark] object Task {
dataOut.writeLong(timestamp)
}
+ // Write the task properties separately so it is available before full task deserialization.
+ val propBytes = Utils.serialize(task.localProperties)
+ dataOut.writeInt(propBytes.length)
+ dataOut.write(propBytes)
+
// Write the task itself and finish
dataOut.flush()
val taskBytes = serializer.serialize(task)
@@ -220,7 +236,7 @@ private[spark] object Task {
* @return (taskFiles, taskJars, taskBytes)
*/
def deserializeWithDependencies(serializedTask: ByteBuffer)
- : (HashMap[String, Long], HashMap[String, Long], ByteBuffer) = {
+ : (HashMap[String, Long], HashMap[String, Long], Properties, ByteBuffer) = {
val in = new ByteBufferInputStream(serializedTask)
val dataIn = new DataInputStream(in)
@@ -239,8 +255,13 @@ private[spark] object Task {
taskJars(dataIn.readUTF()) = dataIn.readLong()
}
+ val propLength = dataIn.readInt()
+ val propBytes = new Array[Byte](propLength)
+ dataIn.readFully(propBytes, 0, propLength)
+ val taskProps = Utils.deserialize[Properties](propBytes)
+
// Create a sub-buffer for the rest of the data, which is the serialized Task object
val subBuffer = serializedTask.slice() // ByteBufferInputStream will have read just up to task
- (taskFiles, taskJars, subBuffer)
+ (taskFiles, taskJars, taskProps, subBuffer)
}
}
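The write and read paths above frame the task properties as a length-prefixed byte block ahead of the task body, so the executor can recover them with readInt + readFully before deserializing the task itself. A standalone sketch of that framing, using plain Java serialization in place of Spark's Utils.serialize/deserialize:

    import java.io._
    import java.util.Properties

    object FramingSketch extends App {
      val props = new Properties()
      props.setProperty("spark.job.description", "example")

      // write: length prefix, then the serialized Properties
      val out = new ByteArrayOutputStream()
      val dataOut = new DataOutputStream(out)
      val propBytes = {
        val b = new ByteArrayOutputStream()
        val o = new ObjectOutputStream(b)
        o.writeObject(props)
        o.flush()
        b.toByteArray
      }
      dataOut.writeInt(propBytes.length)
      dataOut.write(propBytes)
      dataOut.flush()

      // read: recover the properties before touching the rest of the payload
      val dataIn = new DataInputStream(new ByteArrayInputStream(out.toByteArray))
      val buf = new Array[Byte](dataIn.readInt())
      dataIn.readFully(buf, 0, buf.length)
      val restored = new ObjectInputStream(new ByteArrayInputStream(buf))
        .readObject().asInstanceOf[Properties]
      println(restored.getProperty("spark.job.description")) // example
    }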
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala
index 873f1b56bd..ae7ef46abb 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala
@@ -133,7 +133,7 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedul
// if we can't deserialize the reason.
logError(
"Could not deserialize TaskEndReason: ClassNotFound with classloader " + loader)
- case ex: Exception => {}
+ case ex: Exception => // No-op
}
scheduler.handleFailedTask(taskSetManager, tid, taskState, reason)
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala
index 8477a66b39..647d44a0f0 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala
@@ -51,7 +51,7 @@ private[spark] trait TaskScheduler {
def submitTasks(taskSet: TaskSet): Unit
// Cancel a stage.
- def cancelTasks(stageId: Int, interruptThread: Boolean)
+ def cancelTasks(stageId: Int, interruptThread: Boolean): Unit
// Set the DAG scheduler for upcalls. This is guaranteed to be set before submitTasks is called.
def setDAGScheduler(dagScheduler: DAGScheduler): Unit
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
index f7790fccc6..c3159188d9 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
@@ -90,6 +90,8 @@ private[spark] class TaskSchedulerImpl(
// Number of tasks running on each executor
private val executorIdToTaskCount = new HashMap[String, Int]
+ def runningTasksByExecutors(): Map[String, Int] = executorIdToTaskCount.toMap
+
// The set of executors we have on each host; this is used to compute hostsAlive, which
// in turn is used to decide when we can attain data locality on a given host
protected val executorsByHost = new HashMap[String, HashSet[String]]
@@ -569,6 +571,11 @@ private[spark] class TaskSchedulerImpl(
return
}
while (!backend.isReady) {
+ // Might take a while for backend to be ready if it is waiting on resources.
+ if (sc.stopped.get) {
+ // For example: the master removes the application for some reason
+ throw new IllegalStateException("Spark context stopped while waiting for backend")
+ }
synchronized {
this.wait(100)
}
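The new guard above turns an indefinite wait into a fast failure when the context is stopped while the backend is still acquiring resources. A hedged sketch of the same poll-with-escape-hatch pattern (names illustrative):

    import java.util.concurrent.atomic.AtomicBoolean

    object WaitSketch {
      def waitUntilReady(isReady: () => Boolean, stopped: AtomicBoolean, lock: AnyRef): Unit = {
        while (!isReady()) {
          if (stopped.get) {
            throw new IllegalStateException("Spark context stopped while waiting for backend")
          }
          lock.synchronized { lock.wait(100) }  // re-check roughly every 100 ms
        }
      }
    }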
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
index 15d3515a02..6e08cdd87a 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
@@ -188,20 +188,18 @@ private[spark] class TaskSetManager(
loc match {
case e: ExecutorCacheTaskLocation =>
pendingTasksForExecutor.getOrElseUpdate(e.executorId, new ArrayBuffer) += index
- case e: HDFSCacheTaskLocation => {
+ case e: HDFSCacheTaskLocation =>
val exe = sched.getExecutorsAliveOnHost(loc.host)
exe match {
- case Some(set) => {
+ case Some(set) =>
for (e <- set) {
pendingTasksForExecutor.getOrElseUpdate(e, new ArrayBuffer) += index
}
logInfo(s"Pending task $index has a cached location at ${e.host} " +
", where there are executors " + set.mkString(","))
- }
case None => logDebug(s"Pending task $index has a cached location at ${e.host} " +
", but there are no executors alive there.")
}
- }
case _ =>
}
pendingTasksForHost.getOrElseUpdate(loc.host, new ArrayBuffer) += index
@@ -437,7 +435,7 @@ private[spark] class TaskSetManager(
}
dequeueTask(execId, host, allowedLocality) match {
- case Some((index, taskLocality, speculative)) => {
+ case Some((index, taskLocality, speculative)) =>
// Found a task; do some bookkeeping and return a task description
val task = tasks(index)
val taskId = sched.newTaskId()
@@ -486,7 +484,6 @@ private[spark] class TaskSetManager(
sched.dagScheduler.taskStarted(task, info)
return Some(new TaskDescription(taskId = taskId, attemptNumber = attemptNum, execId,
taskName, index, serializedTask))
- }
case _ =>
}
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala
index 8d5c11dc36..46a829114e 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala
@@ -30,6 +30,8 @@ private[spark] object CoarseGrainedClusterMessages {
case object RetrieveSparkProps extends CoarseGrainedClusterMessage
+ case object RetrieveLastAllocatedExecutorId extends CoarseGrainedClusterMessage
+
// Driver to executors
case class LaunchTask(data: SerializableBuffer) extends CoarseGrainedClusterMessage
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
index b7919efc4b..8896391f97 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
@@ -19,6 +19,7 @@ package org.apache.spark.scheduler.cluster
import java.util.concurrent.TimeUnit
import java.util.concurrent.atomic.AtomicInteger
+import javax.annotation.concurrent.GuardedBy
import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet}
@@ -43,24 +44,30 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp
extends ExecutorAllocationClient with SchedulerBackend with Logging
{
// Use an atomic variable to track total number of cores in the cluster for simplicity and speed
- var totalCoreCount = new AtomicInteger(0)
+ protected val totalCoreCount = new AtomicInteger(0)
// Total number of executors that are currently registered
- var totalRegisteredExecutors = new AtomicInteger(0)
- val conf = scheduler.sc.conf
+ protected val totalRegisteredExecutors = new AtomicInteger(0)
+ protected val conf = scheduler.sc.conf
private val maxRpcMessageSize = RpcUtils.maxMessageSizeBytes(conf)
// Submit tasks only after (registered resources / total expected resources)
// is equal to at least this value, that is double between 0 and 1.
- var minRegisteredRatio =
+ private val _minRegisteredRatio =
math.min(1, conf.getDouble("spark.scheduler.minRegisteredResourcesRatio", 0))
// Submit tasks after maxRegisteredWaitingTime milliseconds
// if minRegisteredRatio has not yet been reached
- val maxRegisteredWaitingTimeMs =
+ private val maxRegisteredWaitingTimeMs =
conf.getTimeAsMs("spark.scheduler.maxRegisteredResourcesWaitingTime", "30s")
- val createTime = System.currentTimeMillis()
+ private val createTime = System.currentTimeMillis()
+ // Accessing `executorDataMap` in `DriverEndpoint.receive/receiveAndReply` doesn't need any
+ // protection. But accessing `executorDataMap` out of `DriverEndpoint.receive/receiveAndReply`
+ // must be protected by `CoarseGrainedSchedulerBackend.this`. Besides, `executorDataMap` should
+ // only be modified in `DriverEndpoint.receive/receiveAndReply` with protection by
+ // `CoarseGrainedSchedulerBackend.this`.
private val executorDataMap = new HashMap[String, ExecutorData]
// Number of executors requested from the cluster manager that have not registered yet
+ @GuardedBy("CoarseGrainedSchedulerBackend.this")
private var numPendingExecutors = 0
private val listenerBus = scheduler.sc.listenerBus
@@ -68,20 +75,26 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp
// Executors we have requested the cluster manager to kill that have not died yet; maps
// the executor ID to whether it was explicitly killed by the driver (and thus shouldn't
// be considered an app-related failure).
+ @GuardedBy("CoarseGrainedSchedulerBackend.this")
private val executorsPendingToRemove = new HashMap[String, Boolean]
// A map to store hostname with its possible task number running on it
+ @GuardedBy("CoarseGrainedSchedulerBackend.this")
protected var hostToLocalTaskCount: Map[String, Int] = Map.empty
// The number of pending tasks which is locality required
+ @GuardedBy("CoarseGrainedSchedulerBackend.this")
protected var localityAwareTasks = 0
- // Executors that have been lost, but for which we don't yet know the real exit reason.
- protected val executorsPendingLossReason = new HashSet[String]
+ // The num of current max ExecutorId used to re-register appMaster
+ @volatile protected var currentExecutorIdCounter = 0
class DriverEndpoint(override val rpcEnv: RpcEnv, sparkProperties: Seq[(String, String)])
extends ThreadSafeRpcEndpoint with Logging {
+ // Executors that have been lost, but for which we don't yet know the real exit reason.
+ protected val executorsPendingLossReason = new HashSet[String]
+
// If this DriverEndpoint is changed to support multiple threads,
// then this may need to be changed so that we don't share the serializer
// instance across threads
@@ -137,7 +150,8 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp
case RegisterExecutor(executorId, executorRef, cores, logUrls) =>
if (executorDataMap.contains(executorId)) {
- context.reply(RegisterExecutorFailed("Duplicate executor ID: " + executorId))
+ executorRef.send(RegisterExecutorFailed("Duplicate executor ID: " + executorId))
+ context.reply(true)
} else {
// If the executor's rpc env is not listening for incoming connections, `hostPort`
// will be null, and the client connection should be used to contact the executor.
@@ -156,13 +170,17 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp
// in this block are read when requesting executors
CoarseGrainedSchedulerBackend.this.synchronized {
executorDataMap.put(executorId, data)
+ if (currentExecutorIdCounter < executorId.toInt) {
+ currentExecutorIdCounter = executorId.toInt
+ }
if (numPendingExecutors > 0) {
numPendingExecutors -= 1
logDebug(s"Decremented number of pending executors ($numPendingExecutors left)")
}
}
+ executorRef.send(RegisteredExecutor(executorAddress.host))
// Note: some tests expect the reply to come after we put the executor in the map
- context.reply(RegisteredExecutor(executorAddress.host))
+ context.reply(true)
listenerBus.post(
SparkListenerExecutorAdded(System.currentTimeMillis(), executorId, data))
makeOffers()
@@ -255,7 +273,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp
}
// Remove a disconnected slave from the cluster
- def removeExecutor(executorId: String, reason: ExecutorLossReason): Unit = {
+ private def removeExecutor(executorId: String, reason: ExecutorLossReason): Unit = {
executorDataMap.get(executorId) match {
case Some(executorInfo) =>
// This must be synchronized because variables mutated
@@ -307,7 +325,8 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp
}
var driverEndpoint: RpcEndpointRef = null
- val taskIdsOnSlave = new HashMap[String, HashSet[String]]
+
+ protected def minRegisteredRatio: Double = _minRegisteredRatio
override def start() {
val properties = new ArrayBuffer[(String, String)]
@@ -356,20 +375,17 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp
/**
* Reset the state of CoarseGrainedSchedulerBackend to the initial state. Currently it will only
- * be called in the yarn-client mode when AM re-registers after a failure, also dynamic
- * allocation is enabled.
+ * be called in the yarn-client mode when AM re-registers after a failure.
* */
protected def reset(): Unit = synchronized {
- if (Utils.isDynamicAllocationEnabled(conf)) {
- numPendingExecutors = 0
- executorsPendingToRemove.clear()
-
- // Remove all the lingering executors that should be removed but not yet. The reason might be
- // because (1) disconnected event is not yet received; (2) executors die silently.
- executorDataMap.toMap.foreach { case (eid, _) =>
- driverEndpoint.askWithRetry[Boolean](
- RemoveExecutor(eid, SlaveLost("Stale executor after cluster manager re-registered.")))
- }
+ numPendingExecutors = 0
+ executorsPendingToRemove.clear()
+
+ // Remove all the lingering executors that should be removed but not yet. The reason might be
+ // because (1) disconnected event is not yet received; (2) executors die silently.
+ executorDataMap.toMap.foreach { case (eid, _) =>
+ driverEndpoint.askWithRetry[Boolean](
+ RemoveExecutor(eid, SlaveLost("Stale executor after cluster manager re-registered.")))
}
}
@@ -414,7 +430,11 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp
/**
* Return the number of executors currently registered with this backend.
*/
- def numExistingExecutors: Int = executorDataMap.size
+ private def numExistingExecutors: Int = executorDataMap.size
+
+ override def getExecutorIds(): Seq[String] = {
+ executorDataMap.keySet.toSeq
+ }
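The @GuardedBy annotations added above only document the locking convention; they are advisory (javax.annotation.concurrent) and rely on the code actually synchronizing on the named lock, as getExecutorIds does. A hedged, self-contained sketch of the convention (class and field names are illustrative):

    import javax.annotation.concurrent.GuardedBy
    import scala.collection.mutable

    class ExecutorRegistry {
      @GuardedBy("this")
      private val executors = new mutable.HashMap[String, Int]()  // executorId -> cores

      def register(id: String, cores: Int): Unit = synchronized { executors.put(id, cores) }

      def ids(): Seq[String] = synchronized { executors.keySet.toSeq }
    }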
/**
* Request an additional number of executors from the cluster manager.
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala
index 90b1813750..50b452c72f 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala
@@ -295,12 +295,12 @@ private[spark] class CoarseMesosSchedulerBackend(
}
/**
- * Launches executors on accepted offers, and declines unused offers. Executors are launched
- * round-robin on offers.
- *
- * @param d SchedulerDriver
- * @param offers Mesos offers that match attribute constraints
- */
+ * Launches executors on accepted offers, and declines unused offers. Executors are launched
+ * round-robin on offers.
+ *
+ * @param d SchedulerDriver
+ * @param offers Mesos offers that match attribute constraints
+ */
private def handleMatchedOffers(d: SchedulerDriver, offers: Buffer[Offer]): Unit = {
val tasks = buildMesosTasks(offers)
for (offer <- offers) {
@@ -336,12 +336,12 @@ private[spark] class CoarseMesosSchedulerBackend(
}
/**
- * Returns a map from OfferIDs to the tasks to launch on those offers. In order to maximize
- * per-task memory and IO, tasks are round-robin assigned to offers.
- *
- * @param offers Mesos offers that match attribute constraints
- * @return A map from OfferID to a list of Mesos tasks to launch on that offer
- */
+ * Returns a map from OfferIDs to the tasks to launch on those offers. In order to maximize
+ * per-task memory and IO, tasks are round-robin assigned to offers.
+ *
+ * @param offers Mesos offers that match attribute constraints
+ * @return A map from OfferID to a list of Mesos tasks to launch on that offer
+ */
private def buildMesosTasks(offers: Buffer[Offer]): Map[OfferID, List[MesosTaskInfo]] = {
// offerID -> tasks
val tasks = new HashMap[OfferID, List[MesosTaskInfo]].withDefaultValue(Nil)
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterPersistenceEngine.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterPersistenceEngine.scala
index 3971e6c382..61ab3e87c5 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterPersistenceEngine.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterPersistenceEngine.scala
@@ -121,11 +121,10 @@ private[spark] class ZookeeperMesosClusterPersistenceEngine(
Some(Utils.deserialize[T](fileData))
} catch {
case e: NoNodeException => None
- case e: Exception => {
+ case e: Exception =>
logWarning("Exception while reading persisted file, deleting", e)
zk.delete().forPath(zkPath)
None
- }
}
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala
index 2df7b1120b..73bd4c58e1 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala
@@ -423,6 +423,12 @@ private[spark] class MesosClusterScheduler(
"--driver-cores", desc.cores.toString,
"--driver-memory", s"${desc.mem}M")
+ val replicatedOptionsBlacklist = Set(
+ "spark.jars", // Avoids duplicate classes in classpath
+ "spark.submit.deployMode", // this would be set to `cluster`, but we need client
+ "spark.master" // this contains the address of the dispatcher, not master
+ )
+
// Assume empty main class means we're running python
if (!desc.command.mainClass.equals("")) {
options ++= Seq("--class", desc.command.mainClass)
@@ -440,9 +446,29 @@ private[spark] class MesosClusterScheduler(
.mkString(",")
options ++= Seq("--py-files", formattedFiles)
}
+ desc.schedulerProperties
+ .filter { case (key, _) => !replicatedOptionsBlacklist.contains(key) }
+ .foreach { case (key, value) => options ++= Seq("--conf", s"$key=${shellEscape(value)}") }
options
}
+ /**
+ * Escape args for Unix-like shells, unless already quoted by the user.
+ * Based on: http://www.gnu.org/software/bash/manual/html_node/Double-Quotes.html
+ * and http://www.grymoire.com/Unix/Quote.html
+ * @param value argument
+ * @return escaped argument
+ */
+ private[scheduler] def shellEscape(value: String): String = {
+ val WrappedInQuotes = """^(".+"|'.+')$""".r
+ val ShellSpecialChars = (""".*([ '<>&|\?\*;!#\\(\)"$`]).*""").r
+ value match {
+ case WrappedInQuotes(c) => value // The user already quoted the argument; don't touch it.
+ case ShellSpecialChars(c) => "\"" + value.replaceAll("""(["`\$\\])""", """\\$1""") + "\""
+ case _: String => value // Don't touch harmless strings
+ }
+ }
+
private class ResourceOffer(
val offerId: OfferID,
val slaveId: SlaveID,
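
The escaping rules above are easiest to see on concrete inputs. The following is a minimal standalone sketch (illustrative only, not part of the patch); the private helper is copied here so the example is self-contained:

    object ShellEscapeDemo {
      // Copy of the private helper added above, reproduced only so this sketch runs standalone.
      def shellEscape(value: String): String = {
        val WrappedInQuotes = """^(".+"|'.+')$""".r
        val ShellSpecialChars = (""".*([ '<>&|\?\*;!#\\(\)"$`]).*""").r
        value match {
          case WrappedInQuotes(c) => value     // already quoted by the user, left as-is
          case ShellSpecialChars(c) =>
            "\"" + value.replaceAll("""(["`\$\\])""", """\\$1""") + "\""
          case _: String => value              // harmless string, left untouched
        }
      }

      def main(args: Array[String]): Unit = {
        println(shellEscape("plain-value"))       // plain-value
        println(shellEscape("has space"))         // "has space"
        println(shellEscape("a$b and `cmd`"))     // "a\$b and \`cmd\`"
        println(shellEscape("'already quoted'"))  // 'already quoted'
      }
    }
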
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendUtil.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendUtil.scala
index 374c79a7e5..1b7ac172de 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendUtil.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendUtil.scala
@@ -55,11 +55,10 @@ private[mesos] object MesosSchedulerBackendUtil extends Logging {
Some(vol.setContainerPath(container_path)
.setHostPath(host_path)
.setMode(Volume.Mode.RO))
- case spec => {
+ case spec =>
logWarning(s"Unable to parse volume specs: $volumes. "
+ "Expected form: \"[host-dir:]container-dir[:rw|:ro](, ...)\"")
None
- }
}
}
.map { _.build() }
@@ -90,11 +89,10 @@ private[mesos] object MesosSchedulerBackendUtil extends Logging {
Some(portmap.setHostPort(host_port.toInt)
.setContainerPort(container_port.toInt)
.setProtocol(protocol))
- case spec => {
+ case spec =>
logWarning(s"Unable to parse port mapping specs: $portmaps. "
+ "Expected form: \"host_port:container_port[:udp|:tcp](, ...)\"")
None
- }
}
}
.map { _.build() }
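
The expected volume form warned about above ("[host-dir:]container-dir[:rw|:ro]") reduces to a small split-and-match. A rough standalone sketch under that assumption; the names here are illustrative and not the ones used in MesosSchedulerBackendUtil:

    object VolumeSpecDemo {
      // Decompose one Docker-style volume spec into (hostDir, containerDir, readOnly).
      def parseVolume(spec: String): Option[(Option[String], String, Boolean)] =
        spec.split(":") match {
          case Array(container)             => Some((None, container, false))
          case Array(container, "rw")       => Some((None, container, false))
          case Array(container, "ro")       => Some((None, container, true))
          case Array(host, container)       => Some((Some(host), container, false))
          case Array(host, container, "rw") => Some((Some(host), container, false))
          case Array(host, container, "ro") => Some((Some(host), container, true))
          case _                            => None  // malformed; the real code logs a warning
        }

      def main(args: Array[String]): Unit = {
        println(parseVolume("/host/logs:/container/logs:ro")) // Some((Some(/host/logs),/container/logs,true))
        println(parseVolume("/container/scratch"))            // Some((None,/container/scratch,false))
        println(parseVolume("too:many:parts:here"))           // None
      }
    }
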
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala
index 9a12a61f2f..1e322ac679 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala
@@ -124,11 +124,10 @@ private[mesos] trait MesosSchedulerUtils extends Logging {
markErr()
}
} catch {
- case e: Exception => {
+ case e: Exception =>
logError("driver.run() failed", e)
error = Some(e)
markErr()
- }
}
}
}.start()
@@ -148,8 +147,8 @@ private[mesos] trait MesosSchedulerUtils extends Logging {
}
/**
- * Signal that the scheduler has registered with Mesos.
- */
+ * Signal that the scheduler has registered with Mesos.
+ */
protected def markRegistered(): Unit = {
registerLatch.countDown()
}
@@ -184,7 +183,7 @@ private[mesos] trait MesosSchedulerUtils extends Logging {
var remain = amountToUse
var requestedResources = new ArrayBuffer[Resource]
val remainingResources = resources.asScala.map {
- case r => {
+ case r =>
if (remain > 0 &&
r.getType == Value.Type.SCALAR &&
r.getScalar.getValue > 0.0 &&
@@ -196,7 +195,6 @@ private[mesos] trait MesosSchedulerUtils extends Logging {
} else {
r
}
- }
}
// Filter any resource that has depleted.
@@ -228,7 +226,7 @@ private[mesos] trait MesosSchedulerUtils extends Logging {
* @return
*/
protected def toAttributeMap(offerAttributes: JList[Attribute]): Map[String, GeneratedMessage] = {
- offerAttributes.asScala.map(attr => {
+ offerAttributes.asScala.map { attr =>
val attrValue = attr.getType match {
case Value.Type.SCALAR => attr.getScalar
case Value.Type.RANGES => attr.getRanges
@@ -236,7 +234,7 @@ private[mesos] trait MesosSchedulerUtils extends Logging {
case Value.Type.TEXT => attr.getText
}
(attr.getName, attrValue)
- }).toMap
+ }.toMap
}
@@ -283,11 +281,11 @@ private[mesos] trait MesosSchedulerUtils extends Logging {
* are separated by ':'. The ':' implies equality (for singular values) and "is one of" for
* multiple values (comma separated). For example:
* {{{
- * parseConstraintString("tachyon:true;zone:us-east-1a,us-east-1b")
+ * parseConstraintString("os:centos7;zone:us-east-1a,us-east-1b")
* // would result in
* <code>
* Map(
- * "tachyon" -> Set("true"),
+ * "os" -> Set("centos7"),
* "zone": -> Set("us-east-1a", "us-east-1b")
* )
* }}}
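
A rough standalone sketch of the constraint format documented above (simplified; the real parseConstraintString in this trait also validates input and handles empty value sets):

    object ConstraintFormatDemo {
      def parse(constraints: String): Map[String, Set[String]] =
        constraints.split(";").filter(_.nonEmpty).map { pair =>
          val Array(key, values) = pair.split(":", 2)
          key -> values.split(",").toSet
        }.toMap

      def main(args: Array[String]): Unit = {
        // Map(os -> Set(centos7), zone -> Set(us-east-1a, us-east-1b))
        println(parse("os:centos7;zone:us-east-1a,us-east-1b"))
      }
    }
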
diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
index 3d090a4353..918ae376f6 100644
--- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
+++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
@@ -357,7 +357,7 @@ private[spark] class KryoSerializerInstance(ks: KryoSerializer) extends Serializ
* serialization.
*/
trait KryoRegistrator {
- def registerClasses(kryo: Kryo)
+ def registerClasses(kryo: Kryo): Unit
}
private[serializer] object KryoSerializer {
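
With the explicit Unit return type above, a user-supplied registrator looks the same as before. A minimal sketch; MyRecord and MyKryoRegistrator are hypothetical application classes, and the configuration keys in the trailing comment are the standard Kryo settings:

    import com.esotericsoftware.kryo.Kryo
    import org.apache.spark.serializer.KryoRegistrator

    case class MyRecord(id: Long, payload: String)

    class MyKryoRegistrator extends KryoRegistrator {
      override def registerClasses(kryo: Kryo): Unit = {
        kryo.register(classOf[MyRecord])
      }
    }

    // Enabled via the usual settings, e.g.:
    //   spark.serializer        org.apache.spark.serializer.KryoSerializer
    //   spark.kryo.registrator  MyKryoRegistrator
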
diff --git a/core/src/main/scala/org/apache/spark/serializer/Serializer.scala b/core/src/main/scala/org/apache/spark/serializer/Serializer.scala
index 5ead40e89e..cb95246d5b 100644
--- a/core/src/main/scala/org/apache/spark/serializer/Serializer.scala
+++ b/core/src/main/scala/org/apache/spark/serializer/Serializer.scala
@@ -188,10 +188,9 @@ abstract class DeserializationStream {
try {
(readKey[Any](), readValue[Any]())
} catch {
- case eof: EOFException => {
+ case eof: EOFException =>
finished = true
null
- }
}
}
diff --git a/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala b/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala
index 27e5fa4c2b..745ef12691 100644
--- a/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala
+++ b/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala
@@ -25,7 +25,7 @@ import scala.reflect.ClassTag
import org.apache.spark.SparkConf
import org.apache.spark.io.CompressionCodec
import org.apache.spark.storage._
-import org.apache.spark.util.io.{ByteArrayChunkOutputStream, ChunkedByteBuffer}
+import org.apache.spark.util.io.{ChunkedByteBuffer, ChunkedByteBufferOutputStream}
/**
* Component which configures serialization and compression for various Spark components, including
@@ -128,17 +128,9 @@ private[spark] class SerializerManager(defaultSerializer: Serializer, conf: Spar
/** Serializes into a chunked byte buffer. */
def dataSerialize[T: ClassTag](blockId: BlockId, values: Iterator[T]): ChunkedByteBuffer = {
- val byteArrayChunkOutputStream = new ByteArrayChunkOutputStream(1024 * 1024 * 4)
- dataSerializeStream(blockId, byteArrayChunkOutputStream, values)
- new ChunkedByteBuffer(byteArrayChunkOutputStream.toArrays.map(ByteBuffer.wrap))
- }
-
- /**
- * Deserializes a ByteBuffer into an iterator of values and disposes of it when the end of
- * the iterator is reached.
- */
- def dataDeserialize[T: ClassTag](blockId: BlockId, bytes: ChunkedByteBuffer): Iterator[T] = {
- dataDeserializeStream[T](blockId, bytes.toInputStream(dispose = true))
+ val bbos = new ChunkedByteBufferOutputStream(1024 * 1024 * 4, ByteBuffer.allocate)
+ dataSerializeStream(blockId, bbos, values)
+ bbos.toChunkedByteBuffer
}
/**
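
The switch from ByteArrayChunkOutputStream to ChunkedByteBufferOutputStream above keeps the same idea: serialize into fixed-size chunks produced by a pluggable allocator instead of one contiguous array. A rough standalone sketch of that idea (not Spark's actual class):

    import java.io.OutputStream
    import java.nio.ByteBuffer
    import scala.collection.mutable.ArrayBuffer

    class ChunkingOutputStream(chunkSize: Int, allocate: Int => ByteBuffer) extends OutputStream {
      private val chunks = new ArrayBuffer[ByteBuffer]
      private var current: ByteBuffer = _

      override def write(b: Int): Unit = {
        if (current == null || !current.hasRemaining) {
          current = allocate(chunkSize)   // e.g. ByteBuffer.allocate or a direct-buffer allocator
          chunks += current
        }
        current.put(b.toByte)
      }

      // Read-only views of the written chunks, each flipped for reading.
      def toBuffers: Seq[ByteBuffer] = chunks.map { c =>
        val dup = c.duplicate()
        dup.flip()
        dup
      }
    }

    // Usage mirrors the 4 MB chunk size used in dataSerialize above:
    //   val out = new ChunkingOutputStream(4 * 1024 * 1024, ByteBuffer.allocate)
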
diff --git a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala
index 637b2dfc19..876cdfaa87 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala
@@ -69,10 +69,10 @@ private[spark] class BlockStoreShuffleReader[K, C](
// Update the context task metrics for each record read.
val readMetrics = context.taskMetrics.registerTempShuffleReadMetrics()
val metricIter = CompletionIterator[(Any, Any), Iterator[(Any, Any)]](
- recordIter.map(record => {
+ recordIter.map { record =>
readMetrics.incRecordsRead(1)
record
- }),
+ },
context.taskMetrics().mergeShuffleReadMetrics())
// An interruptible iterator must be used here in order to support task cancellation
diff --git a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala
index 6cd7d69518..be1e84a2ba 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala
@@ -35,7 +35,7 @@ private[spark] trait ShuffleWriterGroup {
val writers: Array[DiskBlockObjectWriter]
/** @param success Indicates all writes were successful. If false, no blocks will be recorded. */
- def releaseWriters(success: Boolean)
+ def releaseWriters(success: Boolean): Unit
}
/**
diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala
index 76fd249fbd..364fad664e 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala
@@ -54,9 +54,9 @@ private[spark] trait ShuffleManager {
context: TaskContext): ShuffleReader[K, C]
/**
- * Remove a shuffle's metadata from the ShuffleManager.
- * @return true if the metadata removed successfully, otherwise false.
- */
+ * Remove a shuffle's metadata from the ShuffleManager.
+ * @return true if the metadata removed successfully, otherwise false.
+ */
def unregisterShuffle(shuffleId: Int): Boolean
/**
diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/AllStagesResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/AllStagesResource.scala
index 9c92a50150..f8d6e9fbbb 100644
--- a/core/src/main/scala/org/apache/spark/status/api/v1/AllStagesResource.scala
+++ b/core/src/main/scala/org/apache/spark/status/api/v1/AllStagesResource.scala
@@ -147,7 +147,7 @@ private[v1] object AllStagesResource {
speculative = uiData.taskInfo.speculative,
accumulatorUpdates = uiData.taskInfo.accumulables.map { convertAccumulableInfo },
errorMessage = uiData.errorMessage,
- taskMetrics = uiData.taskMetrics.map { convertUiTaskMetrics }
+ taskMetrics = uiData.metrics.map { convertUiTaskMetrics }
)
}
@@ -155,7 +155,7 @@ private[v1] object AllStagesResource {
allTaskData: Iterable[TaskUIData],
quantiles: Array[Double]): TaskMetricDistributions = {
- val rawMetrics = allTaskData.flatMap{_.taskMetrics}.toSeq
+ val rawMetrics = allTaskData.flatMap{_.metrics}.toSeq
def metricQuantiles(f: InternalTaskMetrics => Double): IndexedSeq[Double] =
Distribution(rawMetrics.map { d => f(d) }).get.getQuantiles(quantiles)
diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala
index 50b6ba67e9..ba9cd711f1 100644
--- a/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala
+++ b/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala
@@ -177,6 +177,12 @@ private[v1] class ApiRootResource extends UIRootFromServletContext {
@PathParam("attemptId") attemptId: String): EventLogDownloadResource = {
new EventLogDownloadResource(uiRoot, appId, Some(attemptId))
}
+
+ @Path("version")
+ def getVersion(): VersionResource = {
+ new VersionResource(uiRoot)
+ }
+
}
private[spark] object ApiRootResource {
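
The new resource below is served under the same /api/v1 root as the other endpoints, so a running application UI can be queried directly. A small sketch; the host and port assume a local driver UI on the default 4040, and the exact version string depends on the build:

    import scala.io.Source

    object VersionCheck {
      def main(args: Array[String]): Unit = {
        val json = Source.fromURL("http://localhost:4040/api/v1/version").mkString
        println(json)   // e.g. {"spark":"2.0.0-SNAPSHOT"}
      }
    }
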
diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/VersionResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/VersionResource.scala
new file mode 100644
index 0000000000..673da1ce36
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/status/api/v1/VersionResource.scala
@@ -0,0 +1,30 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.status.api.v1
+
+import javax.ws.rs._
+import javax.ws.rs.core.MediaType
+
+@Produces(Array(MediaType.APPLICATION_JSON))
+private[v1] class VersionResource(ui: UIRoot) {
+
+ @GET
+ def getVersionInfo(): VersionInfo = new VersionInfo(
+ org.apache.spark.SPARK_VERSION
+ )
+
+}
diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/api.scala b/core/src/main/scala/org/apache/spark/status/api/v1/api.scala
index 909dd0c07e..ebbbf48148 100644
--- a/core/src/main/scala/org/apache/spark/status/api/v1/api.scala
+++ b/core/src/main/scala/org/apache/spark/status/api/v1/api.scala
@@ -38,7 +38,11 @@ class ApplicationAttemptInfo private[spark](
val lastUpdated: Date,
val duration: Long,
val sparkUser: String,
- val completed: Boolean = false)
+ val completed: Boolean = false) {
+ def getStartTimeEpoch: Long = startTime.getTime
+ def getEndTimeEpoch: Long = endTime.getTime
+ def getLastUpdatedEpoch: Long = lastUpdated.getTime
+}
class ExecutorStageSummary private[spark](
val taskTime : Long,
@@ -237,3 +241,6 @@ class AccumulableInfo private[spark](
val name: String,
val update: Option[String],
val value: String)
+
+class VersionInfo private[spark](
+ val spark: String)
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
index 0c7763f236..35a6c63ad1 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
@@ -18,6 +18,7 @@
package org.apache.spark.storage
import java.io._
+import java.nio.ByteBuffer
import scala.collection.mutable.{ArrayBuffer, HashMap}
import scala.concurrent.{Await, ExecutionContext, Future}
@@ -39,6 +40,7 @@ import org.apache.spark.rpc.RpcEnv
import org.apache.spark.serializer.{SerializerInstance, SerializerManager}
import org.apache.spark.shuffle.ShuffleManager
import org.apache.spark.storage.memory._
+import org.apache.spark.unsafe.Platform
import org.apache.spark.util._
import org.apache.spark.util.io.ChunkedByteBuffer
@@ -372,8 +374,12 @@ private[spark] class BlockManager(
val onDisk = level.useDisk && diskStore.contains(blockId)
val deserialized = if (inMem) level.deserialized else false
val replication = if (inMem || onDisk) level.replication else 1
- val storageLevel =
- StorageLevel(onDisk, inMem, deserialized, replication)
+ val storageLevel = StorageLevel(
+ useDisk = onDisk,
+ useMemory = inMem,
+ useOffHeap = level.useOffHeap,
+ deserialized = deserialized,
+ replication = replication)
val memSize = if (inMem) memoryStore.getSize(blockId) else 0L
val diskSize = if (onDisk) diskStore.getSize(blockId) else 0L
BlockStatus(storageLevel, memSize, diskSize)
@@ -407,8 +413,8 @@ private[spark] class BlockManager(
val iter: Iterator[Any] = if (level.deserialized) {
memoryStore.getValues(blockId).get
} else {
- serializerManager.dataDeserialize(
- blockId, memoryStore.getBytes(blockId).get)(info.classTag)
+ serializerManager.dataDeserializeStream(
+ blockId, memoryStore.getBytes(blockId).get.toInputStream())(info.classTag)
}
val ci = CompletionIterator[Any, Iterator[Any]](iter, releaseLock(blockId))
Some(new BlockResult(ci, DataReadMethod.Memory, info.size))
@@ -416,11 +422,15 @@ private[spark] class BlockManager(
val iterToReturn: Iterator[Any] = {
val diskBytes = diskStore.getBytes(blockId)
if (level.deserialized) {
- val diskValues = serializerManager.dataDeserialize(blockId, diskBytes)(info.classTag)
+ val diskValues = serializerManager.dataDeserializeStream(
+ blockId,
+ diskBytes.toInputStream(dispose = true))(info.classTag)
maybeCacheDiskValuesInMemory(info, blockId, level, diskValues)
} else {
- val bytes = maybeCacheDiskBytesInMemory(info, blockId, level, diskBytes)
- serializerManager.dataDeserialize(blockId, bytes)(info.classTag)
+ val stream = maybeCacheDiskBytesInMemory(info, blockId, level, diskBytes)
+ .map {_.toInputStream(dispose = false)}
+ .getOrElse { diskBytes.toInputStream(dispose = true) }
+ serializerManager.dataDeserializeStream(blockId, stream)(info.classTag)
}
}
val ci = CompletionIterator[Any, Iterator[Any]](iterToReturn, releaseLock(blockId))
@@ -481,7 +491,8 @@ private[spark] class BlockManager(
if (level.useMemory && memoryStore.contains(blockId)) {
memoryStore.getBytes(blockId).get
} else if (level.useDisk && diskStore.contains(blockId)) {
- maybeCacheDiskBytesInMemory(info, blockId, level, diskStore.getBytes(blockId))
+ val diskBytes = diskStore.getBytes(blockId)
+ maybeCacheDiskBytesInMemory(info, blockId, level, diskBytes).getOrElse(diskBytes)
} else {
releaseLock(blockId)
throw new SparkException(s"Block $blockId was not found even though it's read-locked")
@@ -496,8 +507,9 @@ private[spark] class BlockManager(
*/
private def getRemoteValues(blockId: BlockId): Option[BlockResult] = {
getRemoteBytes(blockId).map { data =>
- new BlockResult(
- serializerManager.dataDeserialize(blockId, data), DataReadMethod.Network, data.size)
+ val values =
+ serializerManager.dataDeserializeStream(blockId, data.toInputStream(dispose = true))
+ new BlockResult(values, DataReadMethod.Network, data.size)
}
}
@@ -631,6 +643,14 @@ private[spark] class BlockManager(
level: StorageLevel,
classTag: ClassTag[T],
makeIterator: () => Iterator[T]): Either[BlockResult, Iterator[T]] = {
+ // Attempt to read the block from local or remote storage. If it's present, then we don't need
+ // to go through the local-get-or-put path.
+ get(blockId) match {
+ case Some(block) =>
+ return Left(block)
+ case _ =>
+ // Need to compute the block.
+ }
// Initially we hold no locks on this block.
doPutIterator(blockId, makeIterator, level, classTag, keepReadLock = true) match {
case None =>
@@ -745,7 +765,8 @@ private[spark] class BlockManager(
// Put it in memory first, even if it also has useDisk set to true;
// We will drop it to disk later if the memory store can't hold it.
val putSucceeded = if (level.deserialized) {
- val values = serializerManager.dataDeserialize(blockId, bytes)(classTag)
+ val values =
+ serializerManager.dataDeserializeStream(blockId, bytes.toInputStream())(classTag)
memoryStore.putIteratorAsValues(blockId, values, classTag) match {
case Right(_) => true
case Left(iter) =>
@@ -755,7 +776,7 @@ private[spark] class BlockManager(
false
}
} else {
- memoryStore.putBytes(blockId, size, () => bytes)
+ memoryStore.putBytes(blockId, size, level.memoryMode, () => bytes)
}
if (!putSucceeded && level.useDisk) {
logWarning(s"Persisting block $blockId to disk instead.")
@@ -893,7 +914,7 @@ private[spark] class BlockManager(
}
}
} else { // !level.deserialized
- memoryStore.putIteratorAsBytes(blockId, iterator(), classTag) match {
+ memoryStore.putIteratorAsBytes(blockId, iterator(), classTag, level.memoryMode) match {
case Right(s) =>
size = s
case Left(partiallySerializedValues) =>
@@ -951,14 +972,16 @@ private[spark] class BlockManager(
* Attempts to cache spilled bytes read from disk into the MemoryStore in order to speed up
* subsequent reads. This method requires the caller to hold a read lock on the block.
*
- * @return a copy of the bytes. The original bytes passed this method should no longer
- * be used after this method returns.
+ * @return a copy of the bytes from the memory store if the put succeeded, otherwise None.
+ * If this returns bytes from the memory store then the original disk store bytes will
+ * automatically be disposed and the caller should not continue to use them. Otherwise,
+ * if this returns None then the original disk store bytes will be unaffected.
*/
private def maybeCacheDiskBytesInMemory(
blockInfo: BlockInfo,
blockId: BlockId,
level: StorageLevel,
- diskBytes: ChunkedByteBuffer): ChunkedByteBuffer = {
+ diskBytes: ChunkedByteBuffer): Option[ChunkedByteBuffer] = {
require(!level.deserialized)
if (level.useMemory) {
// Synchronize on blockInfo to guard against a race condition where two readers both try to
@@ -966,25 +989,29 @@ private[spark] class BlockManager(
blockInfo.synchronized {
if (memoryStore.contains(blockId)) {
diskBytes.dispose()
- memoryStore.getBytes(blockId).get
+ Some(memoryStore.getBytes(blockId).get)
} else {
- val putSucceeded = memoryStore.putBytes(blockId, diskBytes.size, () => {
+ val allocator = level.memoryMode match {
+ case MemoryMode.ON_HEAP => ByteBuffer.allocate _
+ case MemoryMode.OFF_HEAP => Platform.allocateDirectBuffer _
+ }
+ val putSucceeded = memoryStore.putBytes(blockId, diskBytes.size, level.memoryMode, () => {
// https://issues.apache.org/jira/browse/SPARK-6076
// If the file size is bigger than the free memory, OOM will happen. So if we
// cannot put it into MemoryStore, copyForMemory should not be created. That's why
// this action is put into a `() => ChunkedByteBuffer` and created lazily.
- diskBytes.copy()
+ diskBytes.copy(allocator)
})
if (putSucceeded) {
diskBytes.dispose()
- memoryStore.getBytes(blockId).get
+ Some(memoryStore.getBytes(blockId).get)
} else {
- diskBytes
+ None
}
}
}
} else {
- diskBytes
+ None
}
}
@@ -1055,7 +1082,12 @@ private[spark] class BlockManager(
val peersForReplication = new ArrayBuffer[BlockManagerId]
val peersReplicatedTo = new ArrayBuffer[BlockManagerId]
val peersFailedToReplicateTo = new ArrayBuffer[BlockManagerId]
- val tLevel = StorageLevel(level.useDisk, level.useMemory, level.deserialized, 1)
+ val tLevel = StorageLevel(
+ useDisk = level.useDisk,
+ useMemory = level.useMemory,
+ useOffHeap = level.useOffHeap,
+ deserialized = level.deserialized,
+ replication = 1)
val startTime = System.currentTimeMillis
val random = new Random(blockId.hashCode)
@@ -1264,9 +1296,12 @@ private[spark] class BlockManager(
"the disk, memory, or external block store")
}
blockInfoManager.removeBlock(blockId)
+ val removeBlockStatus = getCurrentBlockStatus(blockId, info)
if (tellMaster && info.tellMaster) {
- val status = getCurrentBlockStatus(blockId, info)
- reportBlockStatus(blockId, info, status)
+ reportBlockStatus(blockId, info, removeBlockStatus)
+ }
+ Option(TaskContext.get()).foreach { c =>
+ c.taskMetrics().incUpdatedBlockStatuses(Seq((blockId, removeBlockStatus)))
}
}
}
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala
index d2a5c69e15..8fa1215011 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala
@@ -453,7 +453,7 @@ private[spark] class BlockManagerInfo(
}
if (storageLevel.isValid) {
- /* isValid means it is either stored in-memory, on-disk or on-externalBlockStore.
+ /* isValid means it is either stored in-memory or on-disk.
* The memSize here indicates the data size in or dropped from memory,
* externalBlockStoreSize here indicates the data size in or dropped from externalBlockStore,
* and the diskSize here indicates the data size in or dropped to disk.
diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
index 25edb9f1e4..4ec5b4bbb0 100644
--- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
+++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
@@ -143,13 +143,12 @@ final class ShuffleBlockFetcherIterator(
while (iter.hasNext) {
val result = iter.next()
result match {
- case SuccessFetchResult(_, address, _, buf, _) => {
+ case SuccessFetchResult(_, address, _, buf, _) =>
if (address != blockManager.blockManagerId) {
shuffleMetrics.incRemoteBytesRead(buf.size)
shuffleMetrics.incRemoteBlocksFetched(1)
}
buf.release()
- }
case _ =>
}
}
@@ -313,7 +312,7 @@ final class ShuffleBlockFetcherIterator(
shuffleMetrics.incFetchWaitTime(stopFetchWait - startFetchWait)
result match {
- case SuccessFetchResult(_, address, size, buf, isNetworkReqDone) => {
+ case SuccessFetchResult(_, address, size, buf, isNetworkReqDone) =>
if (address != blockManager.blockManagerId) {
shuffleMetrics.incRemoteBytesRead(buf.size)
shuffleMetrics.incRemoteBlocksFetched(1)
@@ -323,7 +322,6 @@ final class ShuffleBlockFetcherIterator(
reqsInFlight -= 1
logDebug("Number of requests in flight " + reqsInFlight)
}
- }
case _ =>
}
// Send fetch requests up to maxBytesInFlight
diff --git a/core/src/main/scala/org/apache/spark/storage/StorageLevel.scala b/core/src/main/scala/org/apache/spark/storage/StorageLevel.scala
index 7d23295e25..216ec07934 100644
--- a/core/src/main/scala/org/apache/spark/storage/StorageLevel.scala
+++ b/core/src/main/scala/org/apache/spark/storage/StorageLevel.scala
@@ -60,10 +60,7 @@ class StorageLevel private(
assert(replication < 40, "Replication restricted to be less than 40 for calculating hash codes")
if (useOffHeap) {
- require(!useDisk, "Off-heap storage level does not support using disk")
- require(!useMemory, "Off-heap storage level does not support using heap memory")
require(!deserialized, "Off-heap storage level does not support deserialized storage")
- require(replication == 1, "Off-heap storage level does not support multiple replication")
}
private[spark] def memoryMode: MemoryMode = {
@@ -86,7 +83,7 @@ class StorageLevel private(
false
}
- def isValid: Boolean = (useMemory || useDisk || useOffHeap) && (replication > 0)
+ def isValid: Boolean = (useMemory || useDisk) && (replication > 0)
def toInt: Int = {
var ret = 0
@@ -123,7 +120,8 @@ class StorageLevel private(
private def readResolve(): Object = StorageLevel.getCachedStorageLevel(this)
override def toString: String = {
- s"StorageLevel($useDisk, $useMemory, $useOffHeap, $deserialized, $replication)"
+ s"StorageLevel(disk=$useDisk, memory=$useMemory, offheap=$useOffHeap, " +
+ s"deserialized=$deserialized, replication=$replication)"
}
override def hashCode(): Int = toInt * 41 + replication
@@ -131,8 +129,9 @@ class StorageLevel private(
def description: String = {
var result = ""
result += (if (useDisk) "Disk " else "")
- result += (if (useMemory) "Memory " else "")
- result += (if (useOffHeap) "ExternalBlockStore " else "")
+ if (useMemory) {
+ result += (if (useOffHeap) "Memory (off heap) " else "Memory ")
+ }
result += (if (deserialized) "Deserialized " else "Serialized ")
result += s"${replication}x Replicated"
result
@@ -156,9 +155,7 @@ object StorageLevel {
val MEMORY_AND_DISK_2 = new StorageLevel(true, true, false, true, 2)
val MEMORY_AND_DISK_SER = new StorageLevel(true, true, false, false)
val MEMORY_AND_DISK_SER_2 = new StorageLevel(true, true, false, false, 2)
-
- // Redirect to MEMORY_ONLY_SER for now.
- val OFF_HEAP = MEMORY_ONLY_SER
+ val OFF_HEAP = new StorageLevel(true, true, true, false, 1)
/**
* :: DeveloperApi ::
@@ -183,7 +180,7 @@ object StorageLevel {
/**
* :: DeveloperApi ::
- * Create a new StorageLevel object without setting useOffHeap.
+ * Create a new StorageLevel object.
*/
@DeveloperApi
def apply(
@@ -198,7 +195,7 @@ object StorageLevel {
/**
* :: DeveloperApi ::
- * Create a new StorageLevel object.
+ * Create a new StorageLevel object without setting useOffHeap.
*/
@DeveloperApi
def apply(
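
Since OFF_HEAP is now a real storage level rather than an alias for MEMORY_ONLY_SER, caching with it follows the usual persist() call. A minimal sketch; the off-heap memory settings shown are an assumed configuration, not part of this patch:

    import org.apache.spark.{SparkConf, SparkContext}
    import org.apache.spark.storage.StorageLevel

    object OffHeapCacheDemo {
      def main(args: Array[String]): Unit = {
        val conf = new SparkConf()
          .setAppName("off-heap-cache-demo")
          .setMaster("local[2]")
          .set("spark.memory.offHeap.enabled", "true")   // assumed: off-heap memory must be enabled
          .set("spark.memory.offHeap.size", "256m")      // assumed: and given a budget
        val sc = new SparkContext(conf)

        val rdd = sc.parallelize(1 to 1000000).persist(StorageLevel.OFF_HEAP)
        println(rdd.count())   // materializes and caches the blocks (serialized, off-heap)

        sc.stop()
      }
    }
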
diff --git a/core/src/main/scala/org/apache/spark/storage/StorageStatusListener.scala b/core/src/main/scala/org/apache/spark/storage/StorageStatusListener.scala
index f552b498a7..3008520f61 100644
--- a/core/src/main/scala/org/apache/spark/storage/StorageStatusListener.scala
+++ b/core/src/main/scala/org/apache/spark/storage/StorageStatusListener.scala
@@ -66,17 +66,6 @@ class StorageStatusListener(conf: SparkConf) extends SparkListener {
}
}
- override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = synchronized {
- val info = taskEnd.taskInfo
- val metrics = taskEnd.taskMetrics
- if (info != null && metrics != null) {
- val updatedBlocks = metrics.updatedBlockStatuses
- if (updatedBlocks.length > 0) {
- updateStorageStatus(info.executorId, updatedBlocks)
- }
- }
- }
-
override def onUnpersistRDD(unpersistRDD: SparkListenerUnpersistRDD): Unit = synchronized {
updateStorageStatus(unpersistRDD.rddId)
}
@@ -102,4 +91,14 @@ class StorageStatusListener(conf: SparkConf) extends SparkListener {
}
}
}
+
+ override def onBlockUpdated(blockUpdated: SparkListenerBlockUpdated): Unit = {
+ val executorId = blockUpdated.blockUpdatedInfo.blockManagerId.executorId
+ val blockId = blockUpdated.blockUpdatedInfo.blockId
+ val storageLevel = blockUpdated.blockUpdatedInfo.storageLevel
+ val memSize = blockUpdated.blockUpdatedInfo.memSize
+ val diskSize = blockUpdated.blockUpdatedInfo.diskSize
+ val blockStatus = BlockStatus(storageLevel, memSize, diskSize)
+ updateStorageStatus(executorId, Seq((blockId, blockStatus)))
+ }
}
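
The listener now derives storage status from SparkListenerBlockUpdated events instead of task-end metrics, and any custom listener can consume the same event. A minimal sketch; BlockUpdateLogger is a hypothetical class, registered with sc.addSparkListener:

    import org.apache.spark.scheduler.{SparkListener, SparkListenerBlockUpdated}

    class BlockUpdateLogger extends SparkListener {
      override def onBlockUpdated(blockUpdated: SparkListenerBlockUpdated): Unit = {
        val info = blockUpdated.blockUpdatedInfo
        println(s"block=${info.blockId} executor=${info.blockManagerId.executorId} " +
          s"level=${info.storageLevel} mem=${info.memSize} disk=${info.diskSize}")
      }
    }
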
diff --git a/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala b/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala
index 199a5fc270..fb9941bbd9 100644
--- a/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala
+++ b/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala
@@ -175,7 +175,10 @@ class StorageStatus(val blockManagerId: BlockManagerId, val maxMem: Long) {
def memRemaining: Long = maxMem - memUsed
/** Return the memory used by this block manager. */
- def memUsed: Long = _nonRddStorageInfo._1 + _rddBlocks.keys.toSeq.map(memUsedByRdd).sum
+ def memUsed: Long = _nonRddStorageInfo._1 + cacheSize
+
+ /** Return the memory used by caching RDDs */
+ def cacheSize: Long = _rddBlocks.keys.toSeq.map(memUsedByRdd).sum
/** Return the disk space used by this block manager. */
def diskUsed: Long = _nonRddStorageInfo._2 + _rddBlocks.keys.toSeq.map(diskUsedByRdd).sum
diff --git a/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala b/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala
index 3ca41f32c1..99be4de065 100644
--- a/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala
+++ b/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala
@@ -32,20 +32,25 @@ import org.apache.spark.internal.Logging
import org.apache.spark.memory.{MemoryManager, MemoryMode}
import org.apache.spark.serializer.{SerializationStream, SerializerManager}
import org.apache.spark.storage.{BlockId, BlockInfoManager, StorageLevel}
+import org.apache.spark.unsafe.Platform
import org.apache.spark.util.{CompletionIterator, SizeEstimator, Utils}
import org.apache.spark.util.collection.SizeTrackingVector
-import org.apache.spark.util.io.{ByteArrayChunkOutputStream, ChunkedByteBuffer}
+import org.apache.spark.util.io.{ChunkedByteBuffer, ChunkedByteBufferOutputStream}
private sealed trait MemoryEntry[T] {
def size: Long
+ def memoryMode: MemoryMode
def classTag: ClassTag[T]
}
private case class DeserializedMemoryEntry[T](
value: Array[T],
size: Long,
- classTag: ClassTag[T]) extends MemoryEntry[T]
+ classTag: ClassTag[T]) extends MemoryEntry[T] {
+ val memoryMode: MemoryMode = MemoryMode.ON_HEAP
+}
private case class SerializedMemoryEntry[T](
buffer: ChunkedByteBuffer,
+ memoryMode: MemoryMode,
classTag: ClassTag[T]) extends MemoryEntry[T] {
def size: Long = buffer.size
}
@@ -86,7 +91,10 @@ private[spark] class MemoryStore(
// A mapping from taskAttemptId to amount of memory used for unrolling a block (in bytes)
// All accesses of this map are assumed to have manually synchronized on `memoryManager`
- private val unrollMemoryMap = mutable.HashMap[Long, Long]()
+ private val onHeapUnrollMemoryMap = mutable.HashMap[Long, Long]()
+ // Note: off-heap unroll memory is only used in putIteratorAsBytes() because off-heap caching
+ // always stores serialized values.
+ private val offHeapUnrollMemoryMap = mutable.HashMap[Long, Long]()
// Initial memory to request before unrolling any block
private val unrollMemoryThreshold: Long =
@@ -131,13 +139,14 @@ private[spark] class MemoryStore(
def putBytes[T: ClassTag](
blockId: BlockId,
size: Long,
+ memoryMode: MemoryMode,
_bytes: () => ChunkedByteBuffer): Boolean = {
require(!contains(blockId), s"Block $blockId is already present in the MemoryStore")
- if (memoryManager.acquireStorageMemory(blockId, size, MemoryMode.ON_HEAP)) {
+ if (memoryManager.acquireStorageMemory(blockId, size, memoryMode)) {
// We acquired enough memory for the block, so go ahead and put it
val bytes = _bytes()
assert(bytes.size == size)
- val entry = new SerializedMemoryEntry[T](bytes, implicitly[ClassTag[T]])
+ val entry = new SerializedMemoryEntry[T](bytes, memoryMode, implicitly[ClassTag[T]])
entries.synchronized {
entries.put(blockId, entry)
}
@@ -190,7 +199,8 @@ private[spark] class MemoryStore(
var vector = new SizeTrackingVector[T]()(classTag)
// Request enough memory to begin unrolling
- keepUnrolling = reserveUnrollMemoryForThisTask(blockId, initialMemoryThreshold)
+ keepUnrolling =
+ reserveUnrollMemoryForThisTask(blockId, initialMemoryThreshold, MemoryMode.ON_HEAP)
if (!keepUnrolling) {
logWarning(s"Failed to reserve initial memory threshold of " +
@@ -207,7 +217,8 @@ private[spark] class MemoryStore(
val currentSize = vector.estimateSize()
if (currentSize >= memoryThreshold) {
val amountToRequest = (currentSize * memoryGrowthFactor - memoryThreshold).toLong
- keepUnrolling = reserveUnrollMemoryForThisTask(blockId, amountToRequest)
+ keepUnrolling =
+ reserveUnrollMemoryForThisTask(blockId, amountToRequest, MemoryMode.ON_HEAP)
if (keepUnrolling) {
unrollMemoryUsedByThisBlock += amountToRequest
}
@@ -228,7 +239,7 @@ private[spark] class MemoryStore(
def transferUnrollToStorage(amount: Long): Unit = {
// Synchronize so that transfer is atomic
memoryManager.synchronized {
- releaseUnrollMemoryForThisTask(amount)
+ releaseUnrollMemoryForThisTask(MemoryMode.ON_HEAP, amount)
val success = memoryManager.acquireStorageMemory(blockId, amount, MemoryMode.ON_HEAP)
assert(success, "transferring unroll memory to storage memory failed")
}
@@ -247,7 +258,7 @@ private[spark] class MemoryStore(
// If this task attempt already owns more unroll memory than is necessary to store the
// block, then release the extra memory that will not be used.
val excessUnrollMemory = unrollMemoryUsedByThisBlock - size
- releaseUnrollMemoryForThisTask(excessUnrollMemory)
+ releaseUnrollMemoryForThisTask(MemoryMode.ON_HEAP, excessUnrollMemory)
transferUnrollToStorage(size)
true
}
@@ -295,10 +306,16 @@ private[spark] class MemoryStore(
private[storage] def putIteratorAsBytes[T](
blockId: BlockId,
values: Iterator[T],
- classTag: ClassTag[T]): Either[PartiallySerializedBlock[T], Long] = {
+ classTag: ClassTag[T],
+ memoryMode: MemoryMode): Either[PartiallySerializedBlock[T], Long] = {
require(!contains(blockId), s"Block $blockId is already present in the MemoryStore")
+ val allocator = memoryMode match {
+ case MemoryMode.ON_HEAP => ByteBuffer.allocate _
+ case MemoryMode.OFF_HEAP => Platform.allocateDirectBuffer _
+ }
+
// Whether there is still enough memory for us to continue unrolling this block
var keepUnrolling = true
// Initial per-task memory to request for unrolling blocks (bytes).
@@ -307,15 +324,15 @@ private[spark] class MemoryStore(
var unrollMemoryUsedByThisBlock = 0L
// Underlying buffer for unrolling the block
val redirectableStream = new RedirectableOutputStream
- val byteArrayChunkOutputStream = new ByteArrayChunkOutputStream(initialMemoryThreshold.toInt)
- redirectableStream.setOutputStream(byteArrayChunkOutputStream)
+ val bbos = new ChunkedByteBufferOutputStream(initialMemoryThreshold.toInt, allocator)
+ redirectableStream.setOutputStream(bbos)
val serializationStream: SerializationStream = {
val ser = serializerManager.getSerializer(classTag).newInstance()
ser.serializeStream(serializerManager.wrapForCompression(blockId, redirectableStream))
}
// Request enough memory to begin unrolling
- keepUnrolling = reserveUnrollMemoryForThisTask(blockId, initialMemoryThreshold)
+ keepUnrolling = reserveUnrollMemoryForThisTask(blockId, initialMemoryThreshold, memoryMode)
if (!keepUnrolling) {
logWarning(s"Failed to reserve initial memory threshold of " +
@@ -325,9 +342,9 @@ private[spark] class MemoryStore(
}
def reserveAdditionalMemoryIfNecessary(): Unit = {
- if (byteArrayChunkOutputStream.size > unrollMemoryUsedByThisBlock) {
- val amountToRequest = byteArrayChunkOutputStream.size - unrollMemoryUsedByThisBlock
- keepUnrolling = reserveUnrollMemoryForThisTask(blockId, amountToRequest)
+ if (bbos.size > unrollMemoryUsedByThisBlock) {
+ val amountToRequest = bbos.size - unrollMemoryUsedByThisBlock
+ keepUnrolling = reserveUnrollMemoryForThisTask(blockId, amountToRequest, memoryMode)
if (keepUnrolling) {
unrollMemoryUsedByThisBlock += amountToRequest
}
@@ -349,12 +366,11 @@ private[spark] class MemoryStore(
}
if (keepUnrolling) {
- val entry = SerializedMemoryEntry[T](
- new ChunkedByteBuffer(byteArrayChunkOutputStream.toArrays.map(ByteBuffer.wrap)), classTag)
+ val entry = SerializedMemoryEntry[T](bbos.toChunkedByteBuffer, memoryMode, classTag)
// Synchronize so that transfer is atomic
memoryManager.synchronized {
- releaseUnrollMemoryForThisTask(unrollMemoryUsedByThisBlock)
- val success = memoryManager.acquireStorageMemory(blockId, entry.size, MemoryMode.ON_HEAP)
+ releaseUnrollMemoryForThisTask(memoryMode, unrollMemoryUsedByThisBlock)
+ val success = memoryManager.acquireStorageMemory(blockId, entry.size, memoryMode)
assert(success, "transferring unroll memory to storage memory failed")
}
entries.synchronized {
@@ -365,7 +381,7 @@ private[spark] class MemoryStore(
Right(entry.size)
} else {
// We ran out of space while unrolling the values for this block
- logUnrollFailureMessage(blockId, byteArrayChunkOutputStream.size)
+ logUnrollFailureMessage(blockId, bbos.size)
Left(
new PartiallySerializedBlock(
this,
@@ -374,7 +390,8 @@ private[spark] class MemoryStore(
serializationStream,
redirectableStream,
unrollMemoryUsedByThisBlock,
- new ChunkedByteBuffer(byteArrayChunkOutputStream.toArrays.map(ByteBuffer.wrap)),
+ memoryMode,
+ bbos.toChunkedByteBuffer,
values,
classTag))
}
@@ -386,7 +403,7 @@ private[spark] class MemoryStore(
case null => None
case e: DeserializedMemoryEntry[_] =>
throw new IllegalArgumentException("should only call getBytes on serialized blocks")
- case SerializedMemoryEntry(bytes, _) => Some(bytes)
+ case SerializedMemoryEntry(bytes, _, _) => Some(bytes)
}
}
@@ -407,8 +424,12 @@ private[spark] class MemoryStore(
entries.remove(blockId)
}
if (entry != null) {
- memoryManager.releaseStorageMemory(entry.size, MemoryMode.ON_HEAP)
- logInfo(s"Block $blockId of size ${entry.size} dropped " +
+ entry match {
+ case SerializedMemoryEntry(buffer, _, _) => buffer.dispose()
+ case _ =>
+ }
+ memoryManager.releaseStorageMemory(entry.size, entry.memoryMode)
+ logDebug(s"Block $blockId of size ${entry.size} dropped " +
s"from memory (free ${maxMemory - blocksMemoryUsed})")
true
} else {
@@ -420,7 +441,8 @@ private[spark] class MemoryStore(
entries.synchronized {
entries.clear()
}
- unrollMemoryMap.clear()
+ onHeapUnrollMemoryMap.clear()
+ offHeapUnrollMemoryMap.clear()
memoryManager.releaseAllStorageMemory()
logInfo("MemoryStore cleared")
}
@@ -433,23 +455,27 @@ private[spark] class MemoryStore(
}
/**
- * Try to evict blocks to free up a given amount of space to store a particular block.
- * Can fail if either the block is bigger than our memory or it would require replacing
- * another block from the same RDD (which leads to a wasteful cyclic replacement pattern for
- * RDDs that don't fit into memory that we want to avoid).
- *
- * @param blockId the ID of the block we are freeing space for, if any
- * @param space the size of this block
- * @return the amount of memory (in bytes) freed by eviction
- */
- private[spark] def evictBlocksToFreeSpace(blockId: Option[BlockId], space: Long): Long = {
+ * Try to evict blocks to free up a given amount of space to store a particular block.
+ * Can fail if either the block is bigger than our memory or it would require replacing
+ * another block from the same RDD (which leads to a wasteful cyclic replacement pattern for
+ * RDDs that don't fit into memory that we want to avoid).
+ *
+ * @param blockId the ID of the block we are freeing space for, if any
+ * @param space the size of this block
+ * @param memoryMode the type of memory to free (on- or off-heap)
+ * @return the amount of memory (in bytes) freed by eviction
+ */
+ private[spark] def evictBlocksToFreeSpace(
+ blockId: Option[BlockId],
+ space: Long,
+ memoryMode: MemoryMode): Long = {
assert(space > 0)
memoryManager.synchronized {
var freedMemory = 0L
val rddToAdd = blockId.flatMap(getRddId)
val selectedBlocks = new ArrayBuffer[BlockId]
- def blockIsEvictable(blockId: BlockId): Boolean = {
- rddToAdd.isEmpty || rddToAdd != getRddId(blockId)
+ def blockIsEvictable(blockId: BlockId, entry: MemoryEntry[_]): Boolean = {
+ entry.memoryMode == memoryMode && (rddToAdd.isEmpty || rddToAdd != getRddId(blockId))
}
// This is synchronized to ensure that the set of entries is not changed
// (because of getValue or getBytes) while traversing the iterator, as that
@@ -459,7 +485,8 @@ private[spark] class MemoryStore(
while (freedMemory < space && iterator.hasNext) {
val pair = iterator.next()
val blockId = pair.getKey
- if (blockIsEvictable(blockId)) {
+ val entry = pair.getValue
+ if (blockIsEvictable(blockId, entry)) {
// We don't want to evict blocks which are currently being read, so we need to obtain
// an exclusive write lock on blocks which are candidates for eviction. We perform a
// non-blocking "tryLock" here in order to ignore blocks which are locked for reading:
@@ -474,7 +501,7 @@ private[spark] class MemoryStore(
def dropBlock[T](blockId: BlockId, entry: MemoryEntry[T]): Unit = {
val data = entry match {
case DeserializedMemoryEntry(values, _, _) => Left(values)
- case SerializedMemoryEntry(buffer, _) => Right(buffer)
+ case SerializedMemoryEntry(buffer, _, _) => Right(buffer)
}
val newEffectiveStorageLevel =
blockEvictionHandler.dropFromMemory(blockId, () => data)(entry.classTag)
@@ -530,11 +557,18 @@ private[spark] class MemoryStore(
*
* @return whether the request is granted.
*/
- def reserveUnrollMemoryForThisTask(blockId: BlockId, memory: Long): Boolean = {
+ def reserveUnrollMemoryForThisTask(
+ blockId: BlockId,
+ memory: Long,
+ memoryMode: MemoryMode): Boolean = {
memoryManager.synchronized {
- val success = memoryManager.acquireUnrollMemory(blockId, memory, MemoryMode.ON_HEAP)
+ val success = memoryManager.acquireUnrollMemory(blockId, memory, memoryMode)
if (success) {
val taskAttemptId = currentTaskAttemptId()
+ val unrollMemoryMap = memoryMode match {
+ case MemoryMode.ON_HEAP => onHeapUnrollMemoryMap
+ case MemoryMode.OFF_HEAP => offHeapUnrollMemoryMap
+ }
unrollMemoryMap(taskAttemptId) = unrollMemoryMap.getOrElse(taskAttemptId, 0L) + memory
}
success
@@ -545,9 +579,13 @@ private[spark] class MemoryStore(
* Release memory used by this task for unrolling blocks.
* If the amount is not specified, remove the current task's allocation altogether.
*/
- def releaseUnrollMemoryForThisTask(memory: Long = Long.MaxValue): Unit = {
+ def releaseUnrollMemoryForThisTask(memoryMode: MemoryMode, memory: Long = Long.MaxValue): Unit = {
val taskAttemptId = currentTaskAttemptId()
memoryManager.synchronized {
+ val unrollMemoryMap = memoryMode match {
+ case MemoryMode.ON_HEAP => onHeapUnrollMemoryMap
+ case MemoryMode.OFF_HEAP => offHeapUnrollMemoryMap
+ }
if (unrollMemoryMap.contains(taskAttemptId)) {
val memoryToRelease = math.min(memory, unrollMemoryMap(taskAttemptId))
if (memoryToRelease > 0) {
@@ -555,7 +593,7 @@ private[spark] class MemoryStore(
if (unrollMemoryMap(taskAttemptId) == 0) {
unrollMemoryMap.remove(taskAttemptId)
}
- memoryManager.releaseUnrollMemory(memoryToRelease, MemoryMode.ON_HEAP)
+ memoryManager.releaseUnrollMemory(memoryToRelease, memoryMode)
}
}
}
@@ -565,20 +603,23 @@ private[spark] class MemoryStore(
* Return the amount of memory currently occupied for unrolling blocks across all tasks.
*/
def currentUnrollMemory: Long = memoryManager.synchronized {
- unrollMemoryMap.values.sum
+ onHeapUnrollMemoryMap.values.sum + offHeapUnrollMemoryMap.values.sum
}
/**
* Return the amount of memory currently occupied for unrolling blocks by this task.
*/
def currentUnrollMemoryForThisTask: Long = memoryManager.synchronized {
- unrollMemoryMap.getOrElse(currentTaskAttemptId(), 0L)
+ onHeapUnrollMemoryMap.getOrElse(currentTaskAttemptId(), 0L) +
+ offHeapUnrollMemoryMap.getOrElse(currentTaskAttemptId(), 0L)
}
/**
* Return the number of tasks currently unrolling blocks.
*/
- private def numTasksUnrolling: Int = memoryManager.synchronized { unrollMemoryMap.keys.size }
+ private def numTasksUnrolling: Int = memoryManager.synchronized {
+ (onHeapUnrollMemoryMap.keys ++ offHeapUnrollMemoryMap.keys).toSet.size
+ }
/**
* Log information about current memory usage.
@@ -627,7 +668,7 @@ private[storage] class PartiallyUnrolledIterator[T](
private[this] var iter: Iterator[T] = {
val completionIterator = CompletionIterator[T, Iterator[T]](unrolled, {
unrolledIteratorIsConsumed = true
- memoryStore.releaseUnrollMemoryForThisTask(unrollMemory)
+ memoryStore.releaseUnrollMemoryForThisTask(MemoryMode.ON_HEAP, unrollMemory)
})
completionIterator ++ rest
}
@@ -640,7 +681,7 @@ private[storage] class PartiallyUnrolledIterator[T](
*/
def close(): Unit = {
if (!unrolledIteratorIsConsumed) {
- memoryStore.releaseUnrollMemoryForThisTask(unrollMemory)
+ memoryStore.releaseUnrollMemoryForThisTask(MemoryMode.ON_HEAP, unrollMemory)
unrolledIteratorIsConsumed = true
}
iter = null
@@ -669,6 +710,7 @@ private class RedirectableOutputStream extends OutputStream {
* @param serializationStream a serialization stream which writes to [[redirectableOutputStream]].
* @param redirectableOutputStream an OutputStream which can be redirected to a different sink.
* @param unrollMemory the amount of unroll memory used by the values in `unrolled`.
+ * @param memoryMode whether the unroll memory is on- or off-heap
* @param unrolled a byte buffer containing the partially-serialized values.
* @param rest the rest of the original iterator passed to
* [[MemoryStore.putIteratorAsValues()]].
@@ -681,18 +723,36 @@ private[storage] class PartiallySerializedBlock[T](
serializationStream: SerializationStream,
redirectableOutputStream: RedirectableOutputStream,
unrollMemory: Long,
+ memoryMode: MemoryMode,
unrolled: ChunkedByteBuffer,
rest: Iterator[T],
classTag: ClassTag[T]) {
+ // If the task does not fully consume `valuesIterator` or otherwise fails to consume or dispose of
+ // this PartiallySerializedBlock, then we risk leaking direct buffers, so we use a task
+ // completion listener here in order to ensure that `unrolled.dispose()` is called at least once.
+ // The dispose() method is idempotent, so it's safe to call it unconditionally.
+ Option(TaskContext.get()).foreach { taskContext =>
+ taskContext.addTaskCompletionListener { _ =>
+ // When a task completes, its unroll memory will automatically be freed. Thus we do not call
+ // releaseUnrollMemoryForThisTask() here because we want to avoid double-freeing.
+ unrolled.dispose()
+ }
+ }
+
/**
* Called to dispose of this block and free its memory.
*/
def discard(): Unit = {
try {
+ // We want to close the output stream in order to free any resources associated with the
+ // serializer itself (such as Kryo's internal buffers). close() might cause data to be
+ // written, so redirect the output stream to discard that data.
+ redirectableOutputStream.setOutputStream(ByteStreams.nullOutputStream())
serializationStream.close()
} finally {
- memoryStore.releaseUnrollMemoryForThisTask(unrollMemory)
+ unrolled.dispose()
+ memoryStore.releaseUnrollMemoryForThisTask(memoryMode, unrollMemory)
}
}
@@ -701,12 +761,14 @@ private[storage] class PartiallySerializedBlock[T](
* and then serializing the values from the original input iterator.
*/
def finishWritingToStream(os: OutputStream): Unit = {
- ByteStreams.copy(unrolled.toInputStream(), os)
+ // `unrolled`'s underlying buffers will be freed once this input stream is fully read:
+ ByteStreams.copy(unrolled.toInputStream(dispose = true), os)
+ memoryStore.releaseUnrollMemoryForThisTask(memoryMode, unrollMemory)
redirectableOutputStream.setOutputStream(os)
while (rest.hasNext) {
serializationStream.writeObject(rest.next())(classTag)
}
- discard()
+ serializationStream.close()
}
/**
@@ -717,10 +779,13 @@ private[storage] class PartiallySerializedBlock[T](
* `close()` on it to free its resources.
*/
def valuesIterator: PartiallyUnrolledIterator[T] = {
+ // `unrolled`'s underlying buffers will be freed once this input stream is fully read:
+ val unrolledIter = serializerManager.dataDeserializeStream(
+ blockId, unrolled.toInputStream(dispose = true))(classTag)
new PartiallyUnrolledIterator(
memoryStore,
unrollMemory,
- unrolled = serializerManager.dataDeserialize(blockId, unrolled)(classTag),
+ unrolled = CompletionIterator[T, Iterator[T]](unrolledIter, discard()),
rest = rest)
}
}
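
Both putIteratorAsBytes and the BlockManager changes above pick a buffer allocator from the requested MemoryMode. A minimal sketch of that selection in isolation:

    import java.nio.ByteBuffer
    import org.apache.spark.memory.MemoryMode
    import org.apache.spark.unsafe.Platform

    object AllocatorDemo {
      def allocatorFor(memoryMode: MemoryMode): Int => ByteBuffer = memoryMode match {
        case MemoryMode.ON_HEAP  => ByteBuffer.allocate _            // heap-backed buffers
        case MemoryMode.OFF_HEAP => Platform.allocateDirectBuffer _  // direct (off-heap) buffers
      }

      def main(args: Array[String]): Unit = {
        println(allocatorFor(MemoryMode.ON_HEAP)(16).isDirect)   // false
        println(allocatorFor(MemoryMode.OFF_HEAP)(16).isDirect)  // true
      }
    }
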
diff --git a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala
index c3c59f857d..119165f724 100644
--- a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala
+++ b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala
@@ -30,6 +30,7 @@ import org.eclipse.jetty.server.handler._
import org.eclipse.jetty.server.nio.SelectChannelConnector
import org.eclipse.jetty.server.ssl.SslSelectChannelConnector
import org.eclipse.jetty.servlet._
+import org.eclipse.jetty.util.component.LifeCycle
import org.eclipse.jetty.util.thread.QueuedThreadPool
import org.json4s.JValue
import org.json4s.jackson.JsonMethods.{pretty, render}
@@ -350,4 +351,15 @@ private[spark] object JettyUtils extends Logging {
private[spark] case class ServerInfo(
server: Server,
boundPort: Int,
- rootHandler: ContextHandlerCollection)
+ rootHandler: ContextHandlerCollection) {
+
+ def stop(): Unit = {
+ server.stop()
+ // Stop the ThreadPool if it supports the stop() method (through LifeCycle).
+ // This is needed because stopping the Server won't stop the ThreadPool it uses.
+ val threadPool = server.getThreadPool
+ if (threadPool != null && threadPool.isInstanceOf[LifeCycle]) {
+ threadPool.asInstanceOf[LifeCycle].stop
+ }
+ }
+}
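
The same shutdown pattern applies wherever a Jetty Server is handed an externally created thread pool. A rough sketch outside of ServerInfo, using the Jetty 8-era API matching the imports above:

    import org.eclipse.jetty.server.Server
    import org.eclipse.jetty.util.component.LifeCycle

    object JettyShutdown {
      def stopServerAndPool(server: Server): Unit = {
        server.stop()
        // Stopping the Server does not stop the ThreadPool it uses, so stop it explicitly
        // when the pool participates in Jetty's LifeCycle.
        server.getThreadPool match {
          case lc: LifeCycle => lc.stop()
          case _ => // pool does not expose LifeCycle; nothing more to do
        }
      }
    }
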
diff --git a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala
index 6057522509..39155ff264 100644
--- a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala
+++ b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala
@@ -80,6 +80,10 @@ private[spark] class SparkUI private (
}
initialize()
+ def getSparkUser: String = {
+ environmentListener.systemProperties.toMap.get("user.name").getOrElse("<unknown>")
+ }
+
def getAppName: String = appName
def setAppId(id: String): Unit = {
diff --git a/core/src/main/scala/org/apache/spark/ui/WebUI.scala b/core/src/main/scala/org/apache/spark/ui/WebUI.scala
index 250b7f2e5f..2b0bc32cf6 100644
--- a/core/src/main/scala/org/apache/spark/ui/WebUI.scala
+++ b/core/src/main/scala/org/apache/spark/ui/WebUI.scala
@@ -129,7 +129,7 @@ private[spark] abstract class WebUI(
}
/** Initialize all components of the server. */
- def initialize()
+ def initialize(): Unit
/** Bind to the HTTP server behind this web interface. */
def bind() {
@@ -153,7 +153,7 @@ private[spark] abstract class WebUI(
def stop() {
assert(serverInfo.isDefined,
"Attempted to stop %s before binding to a server!".format(className))
- serverInfo.get.server.stop()
+ serverInfo.get.stop()
}
}
diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala
index cc476d61b5..a0ef80d9bd 100644
--- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala
+++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala
@@ -38,7 +38,7 @@ private[ui] class ExecutorThreadDumpPage(parent: ExecutorsTab) extends WebUIPage
val content = maybeThreadDump.map { threadDump =>
val dumpRows = threadDump.sortWith {
- case (threadTrace1, threadTrace2) => {
+ case (threadTrace1, threadTrace2) =>
val v1 = if (threadTrace1.threadName.contains("Executor task launch")) 1 else 0
val v2 = if (threadTrace2.threadName.contains("Executor task launch")) 1 else 0
if (v1 == v2) {
@@ -46,7 +46,6 @@ private[ui] class ExecutorThreadDumpPage(parent: ExecutorsTab) extends WebUIPage
} else {
v1 > v2
}
- }
}.map { thread =>
val threadId = thread.threadId
<tr id={s"thread_${threadId}_tr"} class="accordion-heading"
diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala
index 788f35ec77..3fd0efd3a1 100644
--- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala
+++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala
@@ -70,7 +70,7 @@ class ExecutorsListener(storageStatusListener: StorageStatusListener, conf: Spar
executorToLogUrls(eid) = executorAdded.executorInfo.logUrlMap
executorToTotalCores(eid) = executorAdded.executorInfo.totalCores
executorToTasksMax(eid) = executorToTotalCores(eid) / conf.getInt("spark.task.cpus", 1)
- executorIdToData(eid) = ExecutorUIData(executorAdded.time)
+ executorIdToData(eid) = new ExecutorUIData(executorAdded.time)
}
override def onExecutorRemoved(
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala
index d1c8b3089a..07484c9550 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala
@@ -148,7 +148,7 @@ private[ui] class AllJobsPage(parent: JobsTab) extends WebUIPage("") {
| 'Removed at ${UIUtils.formatDate(new Date(event.finishTime.get))}' +
| '${
if (event.finishReason.isDefined) {
- s"""<br>Reason: ${event.finishReason.get}"""
+ s"""<br>Reason: ${event.finishReason.get.replace("\n", " ")}"""
} else {
""
}
@@ -297,6 +297,10 @@ private[ui] class AllJobsPage(parent: JobsTab) extends WebUIPage("") {
<div>
<ul class="unstyled">
<li>
+ <strong>User:</strong>
+ {parent.getSparkUser}
+ </li>
+ <li>
<strong>Total Uptime:</strong>
{
if (endTime < 0 && parent.sc.isDefined) {
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala
index 1304efd8f2..f609fb4cd2 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala
@@ -42,13 +42,13 @@ private[ui] class ExecutorTable(stageId: Int, stageAttemptId: Int, parent: Stage
var hasShuffleWrite = false
var hasShuffleRead = false
var hasBytesSpilled = false
- stageData.foreach(data => {
+ stageData.foreach { data =>
hasInput = data.hasInput
hasOutput = data.hasOutput
hasShuffleRead = data.hasShuffleRead
hasShuffleWrite = data.hasShuffleWrite
hasBytesSpilled = data.hasBytesSpilled
- })
+ }
<table class={UIUtils.TABLE_CLASS_STRIPED_SORTABLE}>
<thead>
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala
index 654d988807..bd4797ae8e 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala
@@ -122,7 +122,7 @@ private[ui] class JobPage(parent: JobsTab) extends WebUIPage("job") {
| 'Removed at ${UIUtils.formatDate(new Date(event.finishTime.get))}' +
| '${
if (event.finishReason.isDefined) {
- s"""<br>Reason: ${event.finishReason.get}"""
+ s"""<br>Reason: ${event.finishReason.get.replace("\n", " ")}"""
} else {
""
}
@@ -203,7 +203,7 @@ private[ui] class JobPage(parent: JobsTab) extends WebUIPage("job") {
// This could be empty if the JobProgressListener hasn't received information about the
// stage or if the stage information has been garbage collected
listener.stageIdToInfo.getOrElse(stageId,
- new StageInfo(stageId, 0, "Unknown", 0, Seq.empty, Seq.empty, "Unknown"))
+ new StageInfo(stageId, 0, "Unknown", 0, Seq.empty, Seq.empty, "Unknown", Seq.empty))
}
val activeStages = Buffer[StageInfo]()
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala
index ed3ab66e3b..13f5f84d06 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala
@@ -396,13 +396,13 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging {
None
}
taskMetrics.foreach { m =>
- val oldMetrics = stageData.taskData.get(info.taskId).flatMap(_.taskMetrics)
+ val oldMetrics = stageData.taskData.get(info.taskId).flatMap(_.metrics)
updateAggregateMetrics(stageData, info.executorId, m, oldMetrics)
}
val taskData = stageData.taskData.getOrElseUpdate(info.taskId, new TaskUIData(info))
taskData.taskInfo = info
- taskData.taskMetrics = taskMetrics
+ taskData.metrics = taskMetrics
taskData.errorMessage = errorMessage
for (
@@ -506,9 +506,9 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging {
val metrics = TaskMetrics.fromAccumulatorUpdates(accumUpdates)
taskData.foreach { t =>
if (!t.taskInfo.finished) {
- updateAggregateMetrics(stageData, executorMetricsUpdate.execId, metrics, t.taskMetrics)
+ updateAggregateMetrics(stageData, executorMetricsUpdate.execId, metrics, t.metrics)
// Overwrite task metrics
- t.taskMetrics = Some(metrics)
+ t.metrics = Some(metrics)
}
}
}
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobsTab.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobsTab.scala
index 0d0e9b00d3..7b00b558d5 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/JobsTab.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobsTab.scala
@@ -31,6 +31,8 @@ private[ui] class JobsTab(parent: SparkUI) extends SparkUITab(parent, "jobs") {
def isFairScheduler: Boolean =
jobProgresslistener.schedulingMode == Some(SchedulingMode.FAIR)
+ def getSparkUser: String = parent.getSparkUser
+
attachPage(new AllJobsPage(this))
attachPage(new JobPage(this))
}
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala
index 689ab7dd5e..8a44bbd9fc 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala
@@ -330,7 +330,7 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") {
else taskTable.dataSource.slicedTaskIds
// Excludes tasks which failed and have incomplete metrics
- val validTasks = tasks.filter(t => t.taskInfo.status == "SUCCESS" && t.taskMetrics.isDefined)
+ val validTasks = tasks.filter(t => t.taskInfo.status == "SUCCESS" && t.metrics.isDefined)
val summaryTable: Option[Seq[Node]] =
if (validTasks.size == 0) {
@@ -348,8 +348,8 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") {
getDistributionQuantiles(data).map(d => <td>{Utils.bytesToString(d.toLong)}</td>)
}
- val deserializationTimes = validTasks.map { case TaskUIData(_, metrics, _) =>
- metrics.get.executorDeserializeTime.toDouble
+ val deserializationTimes = validTasks.map { taskUIData: TaskUIData =>
+ taskUIData.metrics.get.executorDeserializeTime.toDouble
}
val deserializationQuantiles =
<td>
@@ -359,13 +359,13 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") {
</span>
</td> +: getFormattedTimeQuantiles(deserializationTimes)
- val serviceTimes = validTasks.map { case TaskUIData(_, metrics, _) =>
- metrics.get.executorRunTime.toDouble
+ val serviceTimes = validTasks.map { taskUIData: TaskUIData =>
+ taskUIData.metrics.get.executorRunTime.toDouble
}
val serviceQuantiles = <td>Duration</td> +: getFormattedTimeQuantiles(serviceTimes)
- val gcTimes = validTasks.map { case TaskUIData(_, metrics, _) =>
- metrics.get.jvmGCTime.toDouble
+ val gcTimes = validTasks.map { taskUIData: TaskUIData =>
+ taskUIData.metrics.get.jvmGCTime.toDouble
}
val gcQuantiles =
<td>
@@ -374,8 +374,8 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") {
</span>
</td> +: getFormattedTimeQuantiles(gcTimes)
- val serializationTimes = validTasks.map { case TaskUIData(_, metrics, _) =>
- metrics.get.resultSerializationTime.toDouble
+ val serializationTimes = validTasks.map { taskUIData: TaskUIData =>
+ taskUIData.metrics.get.resultSerializationTime.toDouble
}
val serializationQuantiles =
<td>
@@ -385,8 +385,8 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") {
</span>
</td> +: getFormattedTimeQuantiles(serializationTimes)
- val gettingResultTimes = validTasks.map { case TaskUIData(info, _, _) =>
- getGettingResultTime(info, currentTime).toDouble
+ val gettingResultTimes = validTasks.map { taskUIData: TaskUIData =>
+ getGettingResultTime(taskUIData.taskInfo, currentTime).toDouble
}
val gettingResultQuantiles =
<td>
@@ -397,8 +397,8 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") {
</td> +:
getFormattedTimeQuantiles(gettingResultTimes)
- val peakExecutionMemory = validTasks.map { case TaskUIData(_, metrics, _) =>
- metrics.get.peakExecutionMemory.toDouble
+ val peakExecutionMemory = validTasks.map { taskUIData: TaskUIData =>
+ taskUIData.metrics.get.peakExecutionMemory.toDouble
}
val peakExecutionMemoryQuantiles = {
<td>
@@ -412,8 +412,8 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") {
// The scheduler delay includes the network delay to send the task to the worker
// machine and to send back the result (but not the time to fetch the task result,
// if it needed to be fetched from the block manager on the worker).
- val schedulerDelays = validTasks.map { case TaskUIData(info, metrics, _) =>
- getSchedulerDelay(info, metrics.get, currentTime).toDouble
+ val schedulerDelays = validTasks.map { taskUIData: TaskUIData =>
+ getSchedulerDelay(taskUIData.taskInfo, taskUIData.metrics.get, currentTime).toDouble
}
val schedulerDelayTitle = <td><span data-toggle="tooltip"
title={ToolTips.SCHEDULER_DELAY} data-placement="right">Scheduler Delay</span></td>
@@ -427,30 +427,30 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") {
)
}
- val inputSizes = validTasks.map { case TaskUIData(_, metrics, _) =>
- metrics.get.inputMetrics.map(_.bytesRead).getOrElse(0L).toDouble
+ val inputSizes = validTasks.map { taskUIData: TaskUIData =>
+ taskUIData.metrics.get.inputMetrics.map(_.bytesRead).getOrElse(0L).toDouble
}
- val inputRecords = validTasks.map { case TaskUIData(_, metrics, _) =>
- metrics.get.inputMetrics.map(_.recordsRead).getOrElse(0L).toDouble
+ val inputRecords = validTasks.map { taskUIData: TaskUIData =>
+ taskUIData.metrics.get.inputMetrics.map(_.recordsRead).getOrElse(0L).toDouble
}
val inputQuantiles = <td>Input Size / Records</td> +:
getFormattedSizeQuantilesWithRecords(inputSizes, inputRecords)
- val outputSizes = validTasks.map { case TaskUIData(_, metrics, _) =>
- metrics.get.outputMetrics.map(_.bytesWritten).getOrElse(0L).toDouble
+ val outputSizes = validTasks.map { taskUIData: TaskUIData =>
+ taskUIData.metrics.get.outputMetrics.map(_.bytesWritten).getOrElse(0L).toDouble
}
- val outputRecords = validTasks.map { case TaskUIData(_, metrics, _) =>
- metrics.get.outputMetrics.map(_.recordsWritten).getOrElse(0L).toDouble
+ val outputRecords = validTasks.map { taskUIData: TaskUIData =>
+ taskUIData.metrics.get.outputMetrics.map(_.recordsWritten).getOrElse(0L).toDouble
}
val outputQuantiles = <td>Output Size / Records</td> +:
getFormattedSizeQuantilesWithRecords(outputSizes, outputRecords)
- val shuffleReadBlockedTimes = validTasks.map { case TaskUIData(_, metrics, _) =>
- metrics.get.shuffleReadMetrics.map(_.fetchWaitTime).getOrElse(0L).toDouble
+ val shuffleReadBlockedTimes = validTasks.map { taskUIData: TaskUIData =>
+ taskUIData.metrics.get.shuffleReadMetrics.map(_.fetchWaitTime).getOrElse(0L).toDouble
}
val shuffleReadBlockedQuantiles =
<td>
@@ -461,11 +461,11 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") {
</td> +:
getFormattedTimeQuantiles(shuffleReadBlockedTimes)
- val shuffleReadTotalSizes = validTasks.map { case TaskUIData(_, metrics, _) =>
- metrics.get.shuffleReadMetrics.map(_.totalBytesRead).getOrElse(0L).toDouble
+ val shuffleReadTotalSizes = validTasks.map { taskUIData: TaskUIData =>
+ taskUIData.metrics.get.shuffleReadMetrics.map(_.totalBytesRead).getOrElse(0L).toDouble
}
- val shuffleReadTotalRecords = validTasks.map { case TaskUIData(_, metrics, _) =>
- metrics.get.shuffleReadMetrics.map(_.recordsRead).getOrElse(0L).toDouble
+ val shuffleReadTotalRecords = validTasks.map { taskUIData: TaskUIData =>
+ taskUIData.metrics.get.shuffleReadMetrics.map(_.recordsRead).getOrElse(0L).toDouble
}
val shuffleReadTotalQuantiles =
<td>
@@ -476,8 +476,8 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") {
</td> +:
getFormattedSizeQuantilesWithRecords(shuffleReadTotalSizes, shuffleReadTotalRecords)
- val shuffleReadRemoteSizes = validTasks.map { case TaskUIData(_, metrics, _) =>
- metrics.get.shuffleReadMetrics.map(_.remoteBytesRead).getOrElse(0L).toDouble
+ val shuffleReadRemoteSizes = validTasks.map { taskUIData: TaskUIData =>
+ taskUIData.metrics.get.shuffleReadMetrics.map(_.remoteBytesRead).getOrElse(0L).toDouble
}
val shuffleReadRemoteQuantiles =
<td>
@@ -488,25 +488,25 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") {
</td> +:
getFormattedSizeQuantiles(shuffleReadRemoteSizes)
- val shuffleWriteSizes = validTasks.map { case TaskUIData(_, metrics, _) =>
- metrics.get.shuffleWriteMetrics.map(_.bytesWritten).getOrElse(0L).toDouble
+ val shuffleWriteSizes = validTasks.map { taskUIData: TaskUIData =>
+ taskUIData.metrics.get.shuffleWriteMetrics.map(_.bytesWritten).getOrElse(0L).toDouble
}
- val shuffleWriteRecords = validTasks.map { case TaskUIData(_, metrics, _) =>
- metrics.get.shuffleWriteMetrics.map(_.recordsWritten).getOrElse(0L).toDouble
+ val shuffleWriteRecords = validTasks.map { taskUIData: TaskUIData =>
+ taskUIData.metrics.get.shuffleWriteMetrics.map(_.recordsWritten).getOrElse(0L).toDouble
}
val shuffleWriteQuantiles = <td>Shuffle Write Size / Records</td> +:
getFormattedSizeQuantilesWithRecords(shuffleWriteSizes, shuffleWriteRecords)
- val memoryBytesSpilledSizes = validTasks.map { case TaskUIData(_, metrics, _) =>
- metrics.get.memoryBytesSpilled.toDouble
+ val memoryBytesSpilledSizes = validTasks.map { taskUIData: TaskUIData =>
+ taskUIData.metrics.get.memoryBytesSpilled.toDouble
}
val memoryBytesSpilledQuantiles = <td>Shuffle spill (memory)</td> +:
getFormattedSizeQuantiles(memoryBytesSpilledSizes)
- val diskBytesSpilledSizes = validTasks.map { case TaskUIData(_, metrics, _) =>
- metrics.get.diskBytesSpilled.toDouble
+ val diskBytesSpilledSizes = validTasks.map { taskUIData: TaskUIData =>
+ taskUIData.metrics.get.diskBytesSpilled.toDouble
}
val diskBytesSpilledQuantiles = <td>Shuffle spill (disk)</td> +:
getFormattedSizeQuantiles(diskBytesSpilledSizes)
@@ -601,7 +601,7 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") {
def toProportion(time: Long) = time.toDouble / totalExecutionTime * 100
- val metricsOpt = taskUIData.taskMetrics
+ val metricsOpt = taskUIData.metrics
val shuffleReadTime =
metricsOpt.flatMap(_.shuffleReadMetrics.map(_.fetchWaitTime)).getOrElse(0L)
val shuffleReadTimeProportion = toProportion(shuffleReadTime)
@@ -868,7 +868,8 @@ private[ui] class TaskDataSource(
def slicedTaskIds: Set[Long] = _slicedTaskIds
private def taskRow(taskData: TaskUIData): TaskTableRowData = {
- val TaskUIData(info, metrics, errorMessage) = taskData
+ val info = taskData.taskInfo
+ val metrics = taskData.metrics
val duration = if (info.status == "RUNNING") info.timeRunning(currentTime)
else metrics.map(_.executorRunTime).getOrElse(1L)
val formatDuration = if (info.status == "RUNNING") UIUtils.formatDuration(duration)
@@ -1014,7 +1015,7 @@ private[ui] class TaskDataSource(
shuffleRead,
shuffleWrite,
bytesSpilled,
- errorMessage.getOrElse(""))
+ taskData.errorMessage.getOrElse(""))
}
/**
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala b/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala
index 78165d7b74..b454ef1b20 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala
@@ -105,12 +105,12 @@ private[spark] object UIData {
/**
* These are kept mutable and reused throughout a task's lifetime to avoid excessive reallocation.
*/
- case class TaskUIData(
+ class TaskUIData(
var taskInfo: TaskInfo,
- var taskMetrics: Option[TaskMetrics] = None,
+ var metrics: Option[TaskMetrics] = None,
var errorMessage: Option[String] = None)
- case class ExecutorUIData(
+ class ExecutorUIData(
val startTime: Long,
var finishTime: Option[Long] = None,
var finishReason: Option[String] = None)
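Since TaskUIData and ExecutorUIData are no longer case classes after the hunk above, the compiler stops generating `unapply`, which is why the `case TaskUIData(info, metrics, _)` patterns in StagePage.scala are rewritten as plain field accesses. A small sketch of the difference, with a hypothetical holder class rather than the Spark one:

// Hypothetical mutable holder mirroring the shape of the new TaskUIData; not Spark code.
class MutableTaskData(var info: String, var metrics: Option[Long] = None)

object MutableHolderExample {
  def main(args: Array[String]): Unit = {
    val t = new MutableTaskData("task-0")
    t.metrics = Some(42L)                 // fields are read and written directly
    println(t.metrics.getOrElse(0L))      // no pattern match: plain classes have no unapply
  }
}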
diff --git a/core/src/main/scala/org/apache/spark/ui/storage/StorageTab.scala b/core/src/main/scala/org/apache/spark/ui/storage/StorageTab.scala
index 8f75b586e1..50095831b4 100644
--- a/core/src/main/scala/org/apache/spark/ui/storage/StorageTab.scala
+++ b/core/src/main/scala/org/apache/spark/ui/storage/StorageTab.scala
@@ -57,17 +57,6 @@ class StorageListener(storageStatusListener: StorageStatusListener) extends Bloc
StorageUtils.updateRddInfo(rddInfosToUpdate, activeStorageStatusList)
}
- /**
- * Assumes the storage status list is fully up-to-date. This implies the corresponding
- * StorageStatusSparkListener must process the SparkListenerTaskEnd event before this listener.
- */
- override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = synchronized {
- val metrics = taskEnd.taskMetrics
- if (metrics != null && metrics.updatedBlockStatuses.nonEmpty) {
- updateRDDInfo(metrics.updatedBlockStatuses)
- }
- }
-
override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted): Unit = synchronized {
val rddInfos = stageSubmitted.stageInfo.rddInfos
rddInfos.foreach { info => _rddInfoMap.getOrElseUpdate(info.id, info) }
@@ -84,4 +73,14 @@ class StorageListener(storageStatusListener: StorageStatusListener) extends Bloc
override def onUnpersistRDD(unpersistRDD: SparkListenerUnpersistRDD): Unit = synchronized {
_rddInfoMap.remove(unpersistRDD.rddId)
}
+
+ override def onBlockUpdated(blockUpdated: SparkListenerBlockUpdated): Unit = {
+ super.onBlockUpdated(blockUpdated)
+ val blockId = blockUpdated.blockUpdatedInfo.blockId
+ val storageLevel = blockUpdated.blockUpdatedInfo.storageLevel
+ val memSize = blockUpdated.blockUpdatedInfo.memSize
+ val diskSize = blockUpdated.blockUpdatedInfo.diskSize
+ val blockStatus = BlockStatus(storageLevel, memSize, diskSize)
+ updateRDDInfo(Seq((blockId, blockStatus)))
+ }
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ASTNodeSuite.scala b/core/src/main/scala/org/apache/spark/util/CausedBy.scala
index 8b05f9e33d..73df446d98 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ASTNodeSuite.scala
+++ b/core/src/main/scala/org/apache/spark/util/CausedBy.scala
@@ -14,25 +14,23 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-package org.apache.spark.sql.catalyst.parser
-import org.apache.spark.SparkFunSuite
+package org.apache.spark.util
-class ASTNodeSuite extends SparkFunSuite {
- test("SPARK-13157 - remainder must return all input chars") {
- val inputs = Seq(
- ("add jar", "file:///tmp/ab/TestUDTF.jar"),
- ("add jar", "file:///tmp/a@b/TestUDTF.jar"),
- ("add jar", "c:\\windows32\\TestUDTF.jar"),
- ("add jar", "some \nbad\t\tfile\r\n.\njar"),
- ("ADD JAR", "@*#&@(!#@$^*!@^@#(*!@#"),
- ("SET", "foo=bar"),
- ("SET", "foo*)(@#^*@&!#^=bar")
- )
- inputs.foreach {
- case (command, arguments) =>
- val node = ParseDriver.parsePlan(s"$command $arguments", null)
- assert(node.remainder === arguments)
- }
+/**
+ * Extractor object for pulling out the root cause of an error.
+ * If the error contains no cause, it will return the error itself.
+ *
+ * Usage:
+ * try {
+ * ...
+ * } catch {
+ * case CausedBy(ex: CommitDeniedException) => ...
+ * }
+ */
+private[spark] object CausedBy {
+
+ def unapply(e: Throwable): Option[Throwable] = {
+ Option(e.getCause).flatMap(cause => unapply(cause)).orElse(Some(e))
}
}
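A self-contained sketch of how an extractor like the new CausedBy can be used to match on the root cause of a wrapped exception. The real object is private[spark], so this copies the pattern under a hypothetical name and uses plain JDK exceptions instead of CommitDeniedException:

// Standalone copy of the extractor pattern, for illustration only.
object RootCause {
  def unapply(e: Throwable): Option[Throwable] =
    Option(e.getCause).flatMap(unapply).orElse(Some(e))
}

object RootCauseExample {
  def main(args: Array[String]): Unit = {
    val wrapped = new RuntimeException("outer", new IllegalStateException("root"))
    wrapped match {
      case RootCause(root: IllegalStateException) => println(s"root cause: ${root.getMessage}")
      case RootCause(other)                       => println(s"unexpected: $other")
    }
  }
}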
diff --git a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala
index f4772a9803..489688cb08 100644
--- a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala
+++ b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala
@@ -19,7 +19,8 @@ package org.apache.spark.util
import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
-import scala.collection.mutable.{Map, Set}
+import scala.collection.mutable.{Map, Set, Stack}
+import scala.language.existentials
import org.apache.xbean.asm5.{ClassReader, ClassVisitor, MethodVisitor, Type}
import org.apache.xbean.asm5.Opcodes._
@@ -77,15 +78,14 @@ private[spark] object ClosureCleaner extends Logging {
*/
private def getInnerClosureClasses(obj: AnyRef): List[Class[_]] = {
val seen = Set[Class[_]](obj.getClass)
- var stack = List[Class[_]](obj.getClass)
+ val stack = Stack[Class[_]](obj.getClass)
while (!stack.isEmpty) {
- val cr = getClassReader(stack.head)
- stack = stack.tail
+ val cr = getClassReader(stack.pop())
val set = Set[Class[_]]()
cr.accept(new InnerClosureFinder(set), 0)
for (cls <- set -- seen) {
seen += cls
- stack = cls :: stack
+ stack.push(cls)
}
}
(seen - obj.getClass).toList
@@ -218,16 +218,24 @@ private[spark] object ClosureCleaner extends Logging {
// Note that all outer objects but the outermost one (first one in this list) must be closures
var outerPairs: List[(Class[_], AnyRef)] = (outerClasses zip outerObjects).reverse
var parent: AnyRef = null
- if (outerPairs.size > 0 && !isClosure(outerPairs.head._1)) {
- // The closure is ultimately nested inside a class; keep the object of that
- // class without cloning it since we don't want to clone the user's objects.
- // Note that we still need to keep around the outermost object itself because
- // we need it to clone its child closure later (see below).
- logDebug(s" + outermost object is not a closure, so do not clone it: ${outerPairs.head}")
- parent = outerPairs.head._2 // e.g. SparkContext
- outerPairs = outerPairs.tail
- } else if (outerPairs.size > 0) {
- logDebug(s" + outermost object is a closure, so we just keep it: ${outerPairs.head}")
+ if (outerPairs.size > 0) {
+ val (outermostClass, outermostObject) = outerPairs.head
+ if (isClosure(outermostClass)) {
+ logDebug(s" + outermost object is a closure, so we clone it: ${outerPairs.head}")
+ } else if (outermostClass.getName.startsWith("$line")) {
+ // SPARK-14558: if the outermost object is a REPL line object, we should clone and clean it
+ // as it may carry a lot of unnecessary information, e.g. hadoop conf, spark conf, etc.
+ logDebug(s" + outermost object is a REPL line object, so we clone it: ${outerPairs.head}")
+ } else {
+ // The closure is ultimately nested inside a class; keep the object of that
+ // class without cloning it since we don't want to clone the user's objects.
+ // Note that we still need to keep around the outermost object itself because
+ // we need it to clone its child closure later (see below).
+ logDebug(" + outermost object is not a closure or REPL line object, so do not clone it: " +
+ outerPairs.head)
+ parent = outermostObject // e.g. SparkContext
+ outerPairs = outerPairs.tail
+ }
} else {
logDebug(" + there are no enclosing objects!")
}
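A small sketch of the Stack-based traversal that replaces the immutable-List "stack" in getInnerClosureClasses above. It walks a graph of hypothetical items depth-first and collects everything reachable, standing in for the ASM class-reader work:

import scala.collection.mutable.{Set, Stack}

object StackTraversalSketch {
  // Hypothetical adjacency data standing in for "inner closure classes found by ASM".
  val inner = Map("A" -> List("B", "C"), "B" -> List("D"), "C" -> Nil, "D" -> Nil)

  def reachable(start: String): scala.collection.Set[String] = {
    val seen = Set(start)
    val stack = Stack(start)
    while (stack.nonEmpty) {
      val current = stack.pop()
      for (next <- inner.getOrElse(current, Nil) if !seen(next)) {
        seen += next
        stack.push(next)                  // same push/pop shape as the rewritten loop above
      }
    }
    seen - start
  }

  def main(args: Array[String]): Unit = println(reachable("A"))   // contains B, C and D
}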
diff --git a/core/src/main/scala/org/apache/spark/util/EventLoop.scala b/core/src/main/scala/org/apache/spark/util/EventLoop.scala
index 153025cef2..3ea9139e11 100644
--- a/core/src/main/scala/org/apache/spark/util/EventLoop.scala
+++ b/core/src/main/scala/org/apache/spark/util/EventLoop.scala
@@ -47,13 +47,12 @@ private[spark] abstract class EventLoop[E](name: String) extends Logging {
try {
onReceive(event)
} catch {
- case NonFatal(e) => {
+ case NonFatal(e) =>
try {
onError(e)
} catch {
case NonFatal(e) => logError("Unexpected error in " + name, e)
}
- }
}
}
} catch {
diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala
index 09d955300a..558767e36f 100644
--- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala
+++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala
@@ -578,7 +578,9 @@ private[spark] object JsonProtocol {
// The "Stage Infos" field was added in Spark 1.2.0
val stageInfos = Utils.jsonOption(json \ "Stage Infos")
.map(_.extract[Seq[JValue]].map(stageInfoFromJson)).getOrElse {
- stageIds.map(id => new StageInfo(id, 0, "unknown", 0, Seq.empty, Seq.empty, "unknown"))
+ stageIds.map { id =>
+ new StageInfo(id, 0, "unknown", 0, Seq.empty, Seq.empty, "unknown", Seq.empty)
+ }
}
SparkListenerJobStart(jobId, submissionTime, stageInfos, properties)
}
@@ -686,7 +688,7 @@ private[spark] object JsonProtocol {
}
val stageInfo = new StageInfo(
- stageId, attemptId, stageName, numTasks, rddInfos, parentIds, details)
+ stageId, attemptId, stageName, numTasks, rddInfos, parentIds, details, Seq.empty)
stageInfo.submissionTime = submissionTime
stageInfo.completionTime = completionTime
stageInfo.failureReason = failureReason
@@ -811,8 +813,8 @@ private[spark] object JsonProtocol {
Utils.jsonOption(json \ "Input Metrics").foreach { inJson =>
val readMethod = DataReadMethod.withName((inJson \ "Data Read Method").extract[String])
val inputMetrics = metrics.registerInputMetrics(readMethod)
- inputMetrics.incBytesReadInternal((inJson \ "Bytes Read").extract[Long])
- inputMetrics.incRecordsReadInternal((inJson \ "Records Read").extractOpt[Long].getOrElse(0L))
+ inputMetrics.incBytesRead((inJson \ "Bytes Read").extract[Long])
+ inputMetrics.incRecordsRead((inJson \ "Records Read").extractOpt[Long].getOrElse(0L))
}
// Updated blocks
diff --git a/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala b/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala
index 3f627a0145..6861a75612 100644
--- a/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala
+++ b/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala
@@ -151,13 +151,12 @@ object SizeEstimator extends Logging {
// TODO: We could use reflection on the VMOption returned ?
getVMMethod.invoke(bean, "UseCompressedOops").toString.contains("true")
} catch {
- case e: Exception => {
+ case e: Exception =>
// Guess whether they've enabled UseCompressedOops based on whether maxMemory < 32 GB
val guess = Runtime.getRuntime.maxMemory < (32L*1024*1024*1024)
val guessInWords = if (guess) "yes" else "not"
logWarning("Failed to check whether UseCompressedOops is set; assuming " + guessInWords)
return guess
- }
}
}
diff --git a/core/src/main/scala/org/apache/spark/util/UninterruptibleThread.scala b/core/src/main/scala/org/apache/spark/util/UninterruptibleThread.scala
new file mode 100644
index 0000000000..4dcf95177a
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/util/UninterruptibleThread.scala
@@ -0,0 +1,112 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util
+
+import javax.annotation.concurrent.GuardedBy
+
+/**
+ * A special Thread that provides "runUninterruptibly" to allow running code without being
+ * interrupted by `Thread.interrupt()`. If `Thread.interrupt()` is called while `runUninterruptibly`
+ * is running, it won't set the interrupted status immediately. Instead, setting the interrupted
+ * status is deferred until "runUninterruptibly" returns.
+ *
+ * Note: "runUninterruptibly" should be called only in `this` thread.
+ */
+private[spark] class UninterruptibleThread(name: String) extends Thread(name) {
+
+ /** A monitor to protect "uninterruptible" and "interrupted" */
+ private val uninterruptibleLock = new Object
+
+ /**
+ * Indicates if `this` thread is in the uninterruptible status. If so, interrupting
+ * `this` will be deferred until `this` enters the interruptible status.
+ */
+ @GuardedBy("uninterruptibleLock")
+ private var uninterruptible = false
+
+ /**
+ * Indicates if we should interrupt `this` when we are leaving the uninterruptible zone.
+ */
+ @GuardedBy("uninterruptibleLock")
+ private var shouldInterruptThread = false
+
+ /**
+ * Run `f` uninterruptibly in `this` thread. The thread won't be interrupted before returning
+ * from `f`.
+ *
+ * If this method finds that `interrupt` is called before calling `f` and it's not inside another
+ * `runUninterruptibly`, it will throw `InterruptedException`.
+ *
+ * Note: this method should be called only in `this` thread.
+ */
+ def runUninterruptibly[T](f: => T): T = {
+ if (Thread.currentThread() != this) {
+ throw new IllegalStateException(s"Call runUninterruptibly in a wrong thread. " +
+ s"Expected: $this but was ${Thread.currentThread()}")
+ }
+
+ if (uninterruptibleLock.synchronized { uninterruptible }) {
+ // We are already in the uninterruptible status. So just run "f" and return
+ return f
+ }
+
+ uninterruptibleLock.synchronized {
+ // Clear the interrupted status if it's set.
+ if (Thread.interrupted() || shouldInterruptThread) {
+ shouldInterruptThread = false
+ // Since it's interrupted, we don't need to run `f` which may be a long computation.
+ // Throw InterruptedException as we don't have a T to return.
+ throw new InterruptedException()
+ }
+ uninterruptible = true
+ }
+ try {
+ f
+ } finally {
+ uninterruptibleLock.synchronized {
+ uninterruptible = false
+ if (shouldInterruptThread) {
+ // Recover the interrupted status
+ super.interrupt()
+ shouldInterruptThread = false
+ }
+ }
+ }
+ }
+
+ /**
+ * Tests whether `interrupt()` has been called.
+ */
+ override def isInterrupted: Boolean = {
+ super.isInterrupted || uninterruptibleLock.synchronized { shouldInterruptThread }
+ }
+
+ /**
+ * Interrupt `this` thread if possible. If `this` is in the uninterruptible status, it won't be
+ * interrupted until it enters the interruptible status.
+ */
+ override def interrupt(): Unit = {
+ uninterruptibleLock.synchronized {
+ if (uninterruptible) {
+ shouldInterruptThread = true
+ } else {
+ super.interrupt()
+ }
+ }
+ }
+}
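A minimal usage sketch, assuming the class behaves as its scaladoc describes: an interrupt() that arrives inside runUninterruptibly is recorded but only delivered once the block returns. The thread interrupts itself here to keep the example free of timing races:

// Sketch only; UninterruptibleThread is private[spark], so real callers live inside Spark.
object UninterruptibleThreadSketch {
  def main(args: Array[String]): Unit = {
    val t = new org.apache.spark.util.UninterruptibleThread("example") {
      override def run(): Unit = {
        runUninterruptibly {
          interrupt()                                        // deferred: inside the uninterruptible zone
          println(s"inside block, flagged: $isInterrupted")  // true, but not delivered yet
        }
        println(s"after block, delivered: ${Thread.currentThread().isInterrupted}")  // true
      }
    }
    t.start()
    t.join()
  }
}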
diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala
index 73768ff4c8..78e164cff7 100644
--- a/core/src/main/scala/org/apache/spark/util/Utils.scala
+++ b/core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -256,10 +256,11 @@ private[spark] object Utils extends Logging {
dir
}
- /** Copy all data from an InputStream to an OutputStream. NIO way of file stream to file stream
- * copying is disabled by default unless explicitly set transferToEnabled as true,
- * the parameter transferToEnabled should be configured by spark.file.transferTo = [true|false].
- */
+ /**
+ * Copy all data from an InputStream to an OutputStream. NIO way of file stream to file stream
+ * copying is disabled by default unless explicitly set transferToEnabled as true,
+ * the parameter transferToEnabled should be configured by spark.file.transferTo = [true|false].
+ */
def copyStream(in: InputStream,
out: OutputStream,
closeStreams: Boolean = false,
@@ -1120,9 +1121,9 @@ private[spark] object Utils extends Logging {
extraEnvironment: Map[String, String] = Map.empty,
redirectStderr: Boolean = true): String = {
val process = executeCommand(command, workingDir, extraEnvironment, redirectStderr)
- val output = new StringBuffer
+ val output = new StringBuilder
val threadName = "read stdout for " + command(0)
- def appendToOutput(s: String): Unit = output.append(s)
+ def appendToOutput(s: String): Unit = output.append(s).append("\n")
val stdoutThread = processStreamByLine(threadName, process.getInputStream, appendToOutput)
val exitCode = process.waitFor()
stdoutThread.join() // Wait for it to finish reading output
@@ -1259,26 +1260,35 @@ private[spark] object Utils extends Logging {
}
/**
- * Execute a block of code, call the failure callbacks before finally block if there is any
- * exceptions happen. But if exceptions happen in the finally block, do not suppress the original
- * exception.
+ * Execute a block of code and call the failure callbacks in the catch block. If exceptions occur
+ * in either the catch or the finally block, they are appended to the list of suppressed
+ * exceptions of the original exception, which is then rethrown.
*
- * This is primarily an issue with `finally { out.close() }` blocks, where
- * close needs to be called to clean up `out`, but if an exception happened
- * in `out.write`, it's likely `out` may be corrupted and `out.close` will
+ * This is primarily an issue with `catch { abort() }` or `finally { out.close() }` blocks,
+ * where the abort/close needs to be called to clean up `out`, but if an exception happened
+ * in `out.write`, it's likely `out` may be corrupted and `abort` or `out.close` will
* fail as well. This would then suppress the original/likely more meaningful
* exception from the original `out.write` call.
*/
- def tryWithSafeFinallyAndFailureCallbacks[T](block: => T)(finallyBlock: => Unit): T = {
+ def tryWithSafeFinallyAndFailureCallbacks[T](block: => T)
+ (catchBlock: => Unit = (), finallyBlock: => Unit = ()): T = {
var originalThrowable: Throwable = null
try {
block
} catch {
- case t: Throwable =>
+ case cause: Throwable =>
// Purposefully not using NonFatal, because even fatal exceptions
// we don't want to have our finallyBlock suppress
- originalThrowable = t
- TaskContext.get().asInstanceOf[TaskContextImpl].markTaskFailed(t)
+ originalThrowable = cause
+ try {
+ logError("Aborting task", originalThrowable)
+ TaskContext.get().asInstanceOf[TaskContextImpl].markTaskFailed(originalThrowable)
+ catchBlock
+ } catch {
+ case t: Throwable =>
+ originalThrowable.addSuppressed(t)
+ logWarning(s"Suppressing exception in catch: " + t.getMessage, t)
+ }
throw originalThrowable
} finally {
try {
@@ -1564,9 +1574,11 @@ private[spark] object Utils extends Logging {
else -1
}
- /** Returns the system properties map that is thread-safe to iterator over. It gets the
- * properties which have been set explicitly, as well as those for which only a default value
- * has been defined. */
+ /**
+ * Returns the system properties map that is thread-safe to iterate over. It gets the
+ * properties which have been set explicitly, as well as those for which only a default value
+ * has been defined.
+ */
def getSystemProperties: Map[String, String] = {
System.getProperties.stringPropertyNames().asScala
.map(key => (key, System.getProperty(key))).toMap
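The suppression pattern added to tryWithSafeFinallyAndFailureCallbacks above can be illustrated with a standalone helper. This sketch drops the TaskContext failure callback and keeps only the exception bookkeeping, so it is a simplified stand-in, not the Spark method itself:

object SuppressedExceptionSketch {
  // Run `block`, run `catchBlock` on failure, always run `finallyBlock`, and attach any
  // secondary failures to the original exception as suppressed exceptions.
  def tryWithCleanup[T](block: => T)(catchBlock: => Unit = (), finallyBlock: => Unit = ()): T = {
    var original: Throwable = null
    try {
      block
    } catch {
      case cause: Throwable =>
        original = cause
        try catchBlock catch { case t: Throwable => original.addSuppressed(t) }
        throw original
    } finally {
      try finallyBlock catch {
        case t: Throwable =>
          if (original != null) original.addSuppressed(t) else throw t
      }
    }
  }

  def main(args: Array[String]): Unit = {
    try tryWithCleanup { sys.error("write failed") }(catchBlock = sys.error("abort failed"))
    catch {
      case e: Throwable =>
        println(e.getMessage)                                                 // write failed
        e.getSuppressed.foreach(s => println("suppressed: " + s.getMessage))  // abort failed
    }
  }
}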
diff --git a/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala b/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala
index c643c4b63c..fb4706e78d 100644
--- a/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala
+++ b/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala
@@ -41,6 +41,8 @@ private[spark] class ChunkedByteBuffer(var chunks: Array[ByteBuffer]) {
require(chunks.forall(_.limit() > 0), "chunks must be non-empty")
require(chunks.forall(_.position() == 0), "chunks' positions must be 0")
+ private[this] var disposed: Boolean = false
+
/**
* The size of this buffer, in bytes.
*/
@@ -117,11 +119,12 @@ private[spark] class ChunkedByteBuffer(var chunks: Array[ByteBuffer]) {
/**
* Make a copy of this ChunkedByteBuffer, copying all of the backing data into new buffers.
* The new buffer will share no resources with the original buffer.
+ *
+ * @param allocator a method for allocating byte buffers
*/
- def copy(): ChunkedByteBuffer = {
+ def copy(allocator: Int => ByteBuffer): ChunkedByteBuffer = {
val copiedChunks = getChunks().map { chunk =>
- // TODO: accept an allocator in this copy method to integrate with mem. accounting systems
- val newChunk = ByteBuffer.allocate(chunk.limit())
+ val newChunk = allocator(chunk.limit())
newChunk.put(chunk)
newChunk.flip()
newChunk
@@ -136,7 +139,10 @@ private[spark] class ChunkedByteBuffer(var chunks: Array[ByteBuffer]) {
* unfortunately no standard API to do this.
*/
def dispose(): Unit = {
- chunks.foreach(StorageUtils.dispose)
+ if (!disposed) {
+ chunks.foreach(StorageUtils.dispose)
+ disposed = true
+ }
}
}
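The `copy` method now takes an allocator so callers can choose between heap and direct buffers. A tiny sketch of that allocator-injection idea using plain java.nio and a hypothetical helper, not the Spark class:

import java.nio.ByteBuffer

object AllocatorSketch {
  // Hypothetical copy helper mirroring the shape of ChunkedByteBuffer.copy(allocator).
  def copyChunks(chunks: Array[ByteBuffer], allocator: Int => ByteBuffer): Array[ByteBuffer] =
    chunks.map { chunk =>
      val dup = chunk.duplicate()          // don't disturb the original buffer's position
      val fresh = allocator(dup.limit())
      fresh.put(dup)
      fresh.flip()
      fresh
    }

  def main(args: Array[String]): Unit = {
    val src = Array(ByteBuffer.wrap("spark".getBytes("UTF-8")))
    val onHeap  = copyChunks(src, ByteBuffer.allocate)        // heap-backed copies
    val offHeap = copyChunks(src, ByteBuffer.allocateDirect)  // direct copies
    println(onHeap.head.remaining() + " " + offHeap.head.isDirect)
  }
}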
diff --git a/core/src/main/scala/org/apache/spark/util/io/ByteArrayChunkOutputStream.scala b/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBufferOutputStream.scala
index 16fe3be303..67b50d1e70 100644
--- a/core/src/main/scala/org/apache/spark/util/io/ByteArrayChunkOutputStream.scala
+++ b/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBufferOutputStream.scala
@@ -18,19 +18,25 @@
package org.apache.spark.util.io
import java.io.OutputStream
+import java.nio.ByteBuffer
import scala.collection.mutable.ArrayBuffer
+import org.apache.spark.storage.StorageUtils
/**
* An OutputStream that writes to fixed-size chunks of byte arrays.
*
* @param chunkSize size of each chunk, in bytes.
*/
-private[spark]
-class ByteArrayChunkOutputStream(chunkSize: Int) extends OutputStream {
+private[spark] class ChunkedByteBufferOutputStream(
+ chunkSize: Int,
+ allocator: Int => ByteBuffer)
+ extends OutputStream {
- private[this] val chunks = new ArrayBuffer[Array[Byte]]
+ private[this] var toChunkedByteBufferWasCalled = false
+
+ private val chunks = new ArrayBuffer[ByteBuffer]
/** Index of the last chunk. Starting with -1 when the chunks array is empty. */
private[this] var lastChunkIndex = -1
@@ -48,7 +54,7 @@ class ByteArrayChunkOutputStream(chunkSize: Int) extends OutputStream {
override def write(b: Int): Unit = {
allocateNewChunkIfNeeded()
- chunks(lastChunkIndex)(position) = b.toByte
+ chunks(lastChunkIndex).put(b.toByte)
position += 1
_size += 1
}
@@ -58,7 +64,7 @@ class ByteArrayChunkOutputStream(chunkSize: Int) extends OutputStream {
while (written < len) {
allocateNewChunkIfNeeded()
val thisBatch = math.min(chunkSize - position, len - written)
- System.arraycopy(bytes, written + off, chunks(lastChunkIndex), position, thisBatch)
+ chunks(lastChunkIndex).put(bytes, written + off, thisBatch)
written += thisBatch
position += thisBatch
}
@@ -67,33 +73,41 @@ class ByteArrayChunkOutputStream(chunkSize: Int) extends OutputStream {
@inline
private def allocateNewChunkIfNeeded(): Unit = {
+ require(!toChunkedByteBufferWasCalled, "cannot write after toChunkedByteBuffer() is called")
if (position == chunkSize) {
- chunks += new Array[Byte](chunkSize)
+ chunks += allocator(chunkSize)
lastChunkIndex += 1
position = 0
}
}
- def toArrays: Array[Array[Byte]] = {
+ def toChunkedByteBuffer: ChunkedByteBuffer = {
+ require(!toChunkedByteBufferWasCalled, "toChunkedByteBuffer() can only be called once")
+ toChunkedByteBufferWasCalled = true
if (lastChunkIndex == -1) {
- new Array[Array[Byte]](0)
+ new ChunkedByteBuffer(Array.empty[ByteBuffer])
} else {
// Copy the first n-1 chunks to the output, and then create an array that fits the last chunk.
// An alternative would have been returning an array of ByteBuffers, with the last buffer
// bounded to only the last chunk's position. However, given our use case in Spark (to put
// the chunks in block manager), only limiting the view bound of the buffer would still
// require the block manager to store the whole chunk.
- val ret = new Array[Array[Byte]](chunks.size)
+ val ret = new Array[ByteBuffer](chunks.size)
for (i <- 0 until chunks.size - 1) {
ret(i) = chunks(i)
+ ret(i).flip()
}
if (position == chunkSize) {
ret(lastChunkIndex) = chunks(lastChunkIndex)
+ ret(lastChunkIndex).flip()
} else {
- ret(lastChunkIndex) = new Array[Byte](position)
- System.arraycopy(chunks(lastChunkIndex), 0, ret(lastChunkIndex), 0, position)
+ ret(lastChunkIndex) = allocator(position)
+ chunks(lastChunkIndex).flip()
+ ret(lastChunkIndex).put(chunks(lastChunkIndex))
+ ret(lastChunkIndex).flip()
+ StorageUtils.dispose(chunks(lastChunkIndex))
}
- ret
+ new ChunkedByteBuffer(ret)
}
}
}
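The rewritten output stream writes into ByteBuffers and must flip each chunk before exposing it. A small sketch of that put/flip handoff with raw java.nio buffers, not the Spark class:

import java.nio.ByteBuffer

object PutFlipSketch {
  def main(args: Array[String]): Unit = {
    val chunk = ByteBuffer.allocate(8)     // write side: position advances as bytes are put
    chunk.put("abc".getBytes("UTF-8"))
    println(s"before flip: position=${chunk.position()} limit=${chunk.limit()}")  // 3 / 8

    chunk.flip()                           // read side: limit := old position, position := 0
    println(s"after flip:  position=${chunk.position()} limit=${chunk.limit()}")  // 0 / 3

    val out = new Array[Byte](chunk.remaining())
    chunk.get(out)
    println(new String(out, "UTF-8"))      // abc
  }
}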
diff --git a/core/src/main/scala/org/apache/spark/util/logging/RollingPolicy.scala b/core/src/main/scala/org/apache/spark/util/logging/RollingPolicy.scala
index b34880d3a7..6e80db2f51 100644
--- a/core/src/main/scala/org/apache/spark/util/logging/RollingPolicy.scala
+++ b/core/src/main/scala/org/apache/spark/util/logging/RollingPolicy.scala
@@ -32,10 +32,10 @@ private[spark] trait RollingPolicy {
def shouldRollover(bytesToBeWritten: Long): Boolean
/** Notify that rollover has occurred */
- def rolledOver()
+ def rolledOver(): Unit
/** Notify that bytes have been written */
- def bytesWritten(bytes: Long)
+ def bytesWritten(bytes: Long): Unit
/** Get the desired name of the rollover file */
def generateRolledOverFileSuffix(): String
diff --git a/core/src/main/scala/org/apache/spark/util/random/Pseudorandom.scala b/core/src/main/scala/org/apache/spark/util/random/Pseudorandom.scala
index 70f3dd62b9..41f28f6e51 100644
--- a/core/src/main/scala/org/apache/spark/util/random/Pseudorandom.scala
+++ b/core/src/main/scala/org/apache/spark/util/random/Pseudorandom.scala
@@ -26,5 +26,5 @@ import org.apache.spark.annotation.DeveloperApi
@DeveloperApi
trait Pseudorandom {
/** Set random seed. */
- def setSeed(seed: Long)
+ def setSeed(seed: Long): Unit
}
diff --git a/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala b/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala
index 3c61528ab5..8c67364ef1 100644
--- a/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala
+++ b/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala
@@ -39,7 +39,14 @@ import org.apache.spark.annotation.DeveloperApi
trait RandomSampler[T, U] extends Pseudorandom with Cloneable with Serializable {
/** take a random sample */
- def sample(items: Iterator[T]): Iterator[U]
+ def sample(items: Iterator[T]): Iterator[U] =
+ items.filter(_ => sample > 0).asInstanceOf[Iterator[U]]
+
+ /**
+ * Whether to sample the next item or not.
+ * Return how many times the next item will be sampled. Return 0 if it is not sampled.
+ */
+ def sample(): Int
/** return a copy of the RandomSampler object */
override def clone: RandomSampler[T, U] =
@@ -107,21 +114,13 @@ class BernoulliCellSampler[T](lb: Double, ub: Double, complement: Boolean = fals
override def setSeed(seed: Long): Unit = rng.setSeed(seed)
- override def sample(items: Iterator[T]): Iterator[T] = {
+ override def sample(): Int = {
if (ub - lb <= 0.0) {
- if (complement) items else Iterator.empty
+ if (complement) 1 else 0
} else {
- if (complement) {
- items.filter { item => {
- val x = rng.nextDouble()
- (x < lb) || (x >= ub)
- }}
- } else {
- items.filter { item => {
- val x = rng.nextDouble()
- (x >= lb) && (x < ub)
- }}
- }
+ val x = rng.nextDouble()
+ val n = if ((x >= lb) && (x < ub)) 1 else 0
+ if (complement) 1 - n else n
}
}
@@ -155,15 +154,22 @@ class BernoulliSampler[T: ClassTag](fraction: Double) extends RandomSampler[T, T
override def setSeed(seed: Long): Unit = rng.setSeed(seed)
- override def sample(items: Iterator[T]): Iterator[T] = {
+ private lazy val gapSampling: GapSampling =
+ new GapSampling(fraction, rng, RandomSampler.rngEpsilon)
+
+ override def sample(): Int = {
if (fraction <= 0.0) {
- Iterator.empty
+ 0
} else if (fraction >= 1.0) {
- items
+ 1
} else if (fraction <= RandomSampler.defaultMaxGapSamplingFraction) {
- new GapSamplingIterator(items, fraction, rng, RandomSampler.rngEpsilon)
+ gapSampling.sample()
} else {
- items.filter { _ => rng.nextDouble() <= fraction }
+ if (rng.nextDouble() <= fraction) {
+ 1
+ } else {
+ 0
+ }
}
}
@@ -180,7 +186,7 @@ class BernoulliSampler[T: ClassTag](fraction: Double) extends RandomSampler[T, T
* @tparam T item type
*/
@DeveloperApi
-class PoissonSampler[T: ClassTag](
+class PoissonSampler[T](
fraction: Double,
useGapSamplingIfPossible: Boolean) extends RandomSampler[T, T] {
@@ -201,15 +207,29 @@ class PoissonSampler[T: ClassTag](
rngGap.setSeed(seed)
}
- override def sample(items: Iterator[T]): Iterator[T] = {
+ private lazy val gapSamplingReplacement =
+ new GapSamplingReplacement(fraction, rngGap, RandomSampler.rngEpsilon)
+
+ override def sample(): Int = {
if (fraction <= 0.0) {
- Iterator.empty
+ 0
} else if (useGapSamplingIfPossible &&
fraction <= RandomSampler.defaultMaxGapSamplingFraction) {
- new GapSamplingReplacementIterator(items, fraction, rngGap, RandomSampler.rngEpsilon)
+ gapSamplingReplacement.sample()
} else {
+ rng.sample()
+ }
+ }
+
+ override def sample(items: Iterator[T]): Iterator[T] = {
+ if (fraction <= 0.0) {
+ Iterator.empty
+ } else {
+ val useGapSampling = useGapSamplingIfPossible &&
+ fraction <= RandomSampler.defaultMaxGapSamplingFraction
+
items.flatMap { item =>
- val count = rng.sample()
+ val count = if (useGapSampling) gapSamplingReplacement.sample() else rng.sample()
if (count == 0) Iterator.empty else Iterator.fill(count)(item)
}
}
@@ -220,50 +240,36 @@ class PoissonSampler[T: ClassTag](
private[spark]
-class GapSamplingIterator[T: ClassTag](
- var data: Iterator[T],
+class GapSampling(
f: Double,
rng: Random = RandomSampler.newDefaultRNG,
- epsilon: Double = RandomSampler.rngEpsilon) extends Iterator[T] {
+ epsilon: Double = RandomSampler.rngEpsilon) extends Serializable {
require(f > 0.0 && f < 1.0, s"Sampling fraction ($f) must reside on open interval (0, 1)")
require(epsilon > 0.0, s"epsilon ($epsilon) must be > 0")
- /** implement efficient linear-sequence drop until Scala includes fix for jira SI-8835. */
- private val iterDrop: Int => Unit = {
- val arrayClass = Array.empty[T].iterator.getClass
- val arrayBufferClass = ArrayBuffer.empty[T].iterator.getClass
- data.getClass match {
- case `arrayClass` =>
- (n: Int) => { data = data.drop(n) }
- case `arrayBufferClass` =>
- (n: Int) => { data = data.drop(n) }
- case _ =>
- (n: Int) => {
- var j = 0
- while (j < n && data.hasNext) {
- data.next()
- j += 1
- }
- }
- }
- }
-
- override def hasNext: Boolean = data.hasNext
+ private val lnq = math.log1p(-f)
- override def next(): T = {
- val r = data.next()
- advance()
- r
+ /** Return 1 if the next item should be sampled. Otherwise, return 0. */
+ def sample(): Int = {
+ if (countForDropping > 0) {
+ countForDropping -= 1
+ 0
+ } else {
+ advance()
+ 1
+ }
}
- private val lnq = math.log1p(-f)
+ private var countForDropping: Int = 0
- /** skip elements that won't be sampled, according to geometric dist P(k) = (f)(1-f)^k. */
+ /**
+ * Decide the number of elements that won't be sampled,
+ * according to geometric dist P(k) = (f)(1-f)^k.
+ */
private def advance(): Unit = {
val u = math.max(rng.nextDouble(), epsilon)
- val k = (math.log(u) / lnq).toInt
- iterDrop(k)
+ countForDropping = (math.log(u) / lnq).toInt
}
/** advance to first sample as part of object construction. */
@@ -273,73 +279,24 @@ class GapSamplingIterator[T: ClassTag](
// work reliably.
}
+
private[spark]
-class GapSamplingReplacementIterator[T: ClassTag](
- var data: Iterator[T],
- f: Double,
- rng: Random = RandomSampler.newDefaultRNG,
- epsilon: Double = RandomSampler.rngEpsilon) extends Iterator[T] {
+class GapSamplingReplacement(
+ val f: Double,
+ val rng: Random = RandomSampler.newDefaultRNG,
+ epsilon: Double = RandomSampler.rngEpsilon) extends Serializable {
require(f > 0.0, s"Sampling fraction ($f) must be > 0")
require(epsilon > 0.0, s"epsilon ($epsilon) must be > 0")
- /** implement efficient linear-sequence drop until scala includes fix for jira SI-8835. */
- private val iterDrop: Int => Unit = {
- val arrayClass = Array.empty[T].iterator.getClass
- val arrayBufferClass = ArrayBuffer.empty[T].iterator.getClass
- data.getClass match {
- case `arrayClass` =>
- (n: Int) => { data = data.drop(n) }
- case `arrayBufferClass` =>
- (n: Int) => { data = data.drop(n) }
- case _ =>
- (n: Int) => {
- var j = 0
- while (j < n && data.hasNext) {
- data.next()
- j += 1
- }
- }
- }
- }
-
- /** current sampling value, and its replication factor, as we are sampling with replacement. */
- private var v: T = _
- private var rep: Int = 0
-
- override def hasNext: Boolean = data.hasNext || rep > 0
-
- override def next(): T = {
- val r = v
- rep -= 1
- if (rep <= 0) advance()
- r
- }
-
- /**
- * Skip elements with replication factor zero (i.e. elements that won't be sampled).
- * Samples 'k' from geometric distribution P(k) = (1-q)(q)^k, where q = e^(-f), that is
- * q is the probability of Poisson(0; f)
- */
- private def advance(): Unit = {
- val u = math.max(rng.nextDouble(), epsilon)
- val k = (math.log(u) / (-f)).toInt
- iterDrop(k)
- // set the value and replication factor for the next value
- if (data.hasNext) {
- v = data.next()
- rep = poissonGE1
- }
- }
-
- private val q = math.exp(-f)
+ protected val q = math.exp(-f)
/**
* Sample from Poisson distribution, conditioned such that the sampled value is >= 1.
* This is an adaptation from the algorithm for Generating Poisson distributed random variables:
* http://en.wikipedia.org/wiki/Poisson_distribution
*/
- private def poissonGE1: Int = {
+ protected def poissonGE1: Int = {
// simulate that the standard poisson sampling
// gave us at least one iteration, for a sample of >= 1
var pp = q + ((1.0 - q) * rng.nextDouble())
@@ -353,6 +310,28 @@ class GapSamplingReplacementIterator[T: ClassTag](
}
r
}
+ private var countForDropping: Int = 0
+
+ def sample(): Int = {
+ if (countForDropping > 0) {
+ countForDropping -= 1
+ 0
+ } else {
+ val r = poissonGE1
+ advance()
+ r
+ }
+ }
+
+ /**
+ * Skip elements with replication factor zero (i.e. elements that won't be sampled).
+ * Samples 'k' from geometric distribution P(k) = (1-q)(q)^k, where q = e^(-f), that is
+ * q is the probability of Poisson(0; f)
+ */
+ private def advance(): Unit = {
+ val u = math.max(rng.nextDouble(), epsilon)
+ countForDropping = (math.log(u) / (-f)).toInt
+ }
/** advance to first sample as part of object construction. */
advance()
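The GapSampling refactor above keeps the geometric-skip trick but exposes it as a per-item `sample(): Int` instead of a wrapping iterator. A self-contained sketch of the skip-count computation, k = floor(log(u) / log(1 - f)), with hypothetical names:

import scala.util.Random

object GapSamplingSketch {
  // Bernoulli sampling with fraction f: instead of drawing a uniform per element, draw the
  // number of elements to skip from a geometric distribution and only then accept one.
  def sampledIndices(n: Int, f: Double, rng: Random = new Random(7L)): Seq[Int] = {
    require(f > 0.0 && f < 1.0, "fraction must be in (0, 1)")
    val lnq = math.log1p(-f)                  // log(1 - f)
    val picked = Seq.newBuilder[Int]
    var i = gap(rng, lnq)                     // first accepted index
    while (i < n) {
      picked += i
      i += 1 + gap(rng, lnq)                  // skip the next k elements, then accept one
    }
    picked.result()
  }

  private def gap(rng: Random, lnq: Double): Int = {
    val u = math.max(rng.nextDouble(), 1e-12) // avoid log(0), like rngEpsilon in the diff above
    (math.log(u) / lnq).toInt
  }

  def main(args: Array[String]): Unit =
    println(sampledIndices(n = 30, f = 0.2))  // roughly 20% of the indices 0..29
}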
diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java
index 44733dcdaf..30750b1bf1 100644
--- a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java
+++ b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java
@@ -170,11 +170,11 @@ public class UnsafeShuffleWriterSuite {
private UnsafeShuffleWriter<Object, Object> createWriter(
boolean transferToEnabled) throws IOException {
conf.set("spark.file.transferTo", String.valueOf(transferToEnabled));
- return new UnsafeShuffleWriter<Object, Object>(
+ return new UnsafeShuffleWriter<>(
blockManager,
shuffleBlockResolver,
taskMemoryManager,
- new SerializedShuffleHandle<Object, Object>(0, 1, shuffleDep),
+ new SerializedShuffleHandle<>(0, 1, shuffleDep),
0, // map id
taskContext,
conf
diff --git a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java
index 449fb45c30..84b82f5a47 100644
--- a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java
+++ b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java
@@ -182,7 +182,7 @@ public abstract class AbstractBytesToBytesMapSuite {
public void emptyMap() {
BytesToBytesMap map = new BytesToBytesMap(taskMemoryManager, 64, PAGE_SIZE_BYTES);
try {
- Assert.assertEquals(0, map.numElements());
+ Assert.assertEquals(0, map.numKeys());
final int keyLengthInWords = 10;
final int keyLengthInBytes = keyLengthInWords * 8;
final byte[] key = getRandomByteArray(keyLengthInWords);
@@ -204,7 +204,7 @@ public abstract class AbstractBytesToBytesMapSuite {
final BytesToBytesMap.Location loc =
map.lookup(keyData, Platform.BYTE_ARRAY_OFFSET, recordLengthBytes);
Assert.assertFalse(loc.isDefined());
- Assert.assertTrue(loc.putNewKey(
+ Assert.assertTrue(loc.append(
keyData,
Platform.BYTE_ARRAY_OFFSET,
recordLengthBytes,
@@ -232,7 +232,7 @@ public abstract class AbstractBytesToBytesMapSuite {
getByteArray(loc.getValueBase(), loc.getValueOffset(), recordLengthBytes));
try {
- Assert.assertTrue(loc.putNewKey(
+ Assert.assertTrue(loc.append(
keyData,
Platform.BYTE_ARRAY_OFFSET,
recordLengthBytes,
@@ -260,7 +260,7 @@ public abstract class AbstractBytesToBytesMapSuite {
Assert.assertFalse(loc.isDefined());
// Ensure that we store some zero-length keys
if (i % 5 == 0) {
- Assert.assertTrue(loc.putNewKey(
+ Assert.assertTrue(loc.append(
null,
Platform.LONG_ARRAY_OFFSET,
0,
@@ -269,7 +269,7 @@ public abstract class AbstractBytesToBytesMapSuite {
8
));
} else {
- Assert.assertTrue(loc.putNewKey(
+ Assert.assertTrue(loc.append(
value,
Platform.LONG_ARRAY_OFFSET,
8,
@@ -349,7 +349,7 @@ public abstract class AbstractBytesToBytesMapSuite {
KEY_LENGTH
);
Assert.assertFalse(loc.isDefined());
- Assert.assertTrue(loc.putNewKey(
+ Assert.assertTrue(loc.append(
key,
Platform.LONG_ARRAY_OFFSET,
KEY_LENGTH,
@@ -417,7 +417,7 @@ public abstract class AbstractBytesToBytesMapSuite {
key.length
);
Assert.assertFalse(loc.isDefined());
- Assert.assertTrue(loc.putNewKey(
+ Assert.assertTrue(loc.append(
key,
Platform.BYTE_ARRAY_OFFSET,
key.length,
@@ -471,7 +471,7 @@ public abstract class AbstractBytesToBytesMapSuite {
key.length
);
Assert.assertFalse(loc.isDefined());
- Assert.assertTrue(loc.putNewKey(
+ Assert.assertTrue(loc.append(
key,
Platform.BYTE_ARRAY_OFFSET,
key.length,
@@ -514,7 +514,7 @@ public abstract class AbstractBytesToBytesMapSuite {
final BytesToBytesMap.Location loc =
map.lookup(emptyArray, Platform.LONG_ARRAY_OFFSET, 0);
Assert.assertFalse(loc.isDefined());
- Assert.assertFalse(loc.putNewKey(
+ Assert.assertFalse(loc.append(
emptyArray, Platform.LONG_ARRAY_OFFSET, 0, emptyArray, Platform.LONG_ARRAY_OFFSET, 0));
} finally {
map.free();
@@ -535,7 +535,7 @@ public abstract class AbstractBytesToBytesMapSuite {
final long[] arr = new long[]{i};
final BytesToBytesMap.Location loc = map.lookup(arr, Platform.LONG_ARRAY_OFFSET, 8);
success =
- loc.putNewKey(arr, Platform.LONG_ARRAY_OFFSET, 8, arr, Platform.LONG_ARRAY_OFFSET, 8);
+ loc.append(arr, Platform.LONG_ARRAY_OFFSET, 8, arr, Platform.LONG_ARRAY_OFFSET, 8);
if (!success) {
break;
}
@@ -556,7 +556,7 @@ public abstract class AbstractBytesToBytesMapSuite {
for (i = 0; i < 1024; i++) {
final long[] arr = new long[]{i};
final BytesToBytesMap.Location loc = map.lookup(arr, Platform.LONG_ARRAY_OFFSET, 8);
- loc.putNewKey(arr, Platform.LONG_ARRAY_OFFSET, 8, arr, Platform.LONG_ARRAY_OFFSET, 8);
+ loc.append(arr, Platform.LONG_ARRAY_OFFSET, 8, arr, Platform.LONG_ARRAY_OFFSET, 8);
}
BytesToBytesMap.MapIterator iter = map.iterator();
for (i = 0; i < 100; i++) {
@@ -587,6 +587,44 @@ public abstract class AbstractBytesToBytesMapSuite {
}
@Test
+ public void multipleValuesForSameKey() {
+ BytesToBytesMap map =
+ new BytesToBytesMap(taskMemoryManager, blockManager, serializerManager, 1, 0.75, 1024, false);
+ try {
+ int i;
+ for (i = 0; i < 1024; i++) {
+ final long[] arr = new long[]{i};
+ map.lookup(arr, Platform.LONG_ARRAY_OFFSET, 8)
+ .append(arr, Platform.LONG_ARRAY_OFFSET, 8, arr, Platform.LONG_ARRAY_OFFSET, 8);
+ }
+ assert map.numKeys() == 1024;
+ assert map.numValues() == 1024;
+ for (i = 0; i < 1024; i++) {
+ final long[] arr = new long[]{i};
+ map.lookup(arr, Platform.LONG_ARRAY_OFFSET, 8)
+ .append(arr, Platform.LONG_ARRAY_OFFSET, 8, arr, Platform.LONG_ARRAY_OFFSET, 8);
+ }
+ assert map.numKeys() == 1024;
+ assert map.numValues() == 2048;
+ for (i = 0; i < 1024; i++) {
+ final long[] arr = new long[]{i};
+ final BytesToBytesMap.Location loc = map.lookup(arr, Platform.LONG_ARRAY_OFFSET, 8);
+ assert loc.isDefined();
+ assert loc.nextValue();
+ assert !loc.nextValue();
+ }
+ BytesToBytesMap.MapIterator iter = map.iterator();
+ for (i = 0; i < 2048; i++) {
+ assert iter.hasNext();
+ final BytesToBytesMap.Location loc = iter.next();
+ assert loc.isDefined();
+ }
+ } finally {
+ map.free();
+ }
+ }
+
+ @Test
public void initialCapacityBoundsChecking() {
try {
new BytesToBytesMap(taskMemoryManager, 0, PAGE_SIZE_BYTES);
@@ -608,7 +646,7 @@ public abstract class AbstractBytesToBytesMapSuite {
@Test
public void testPeakMemoryUsed() {
- final long recordLengthBytes = 24;
+ final long recordLengthBytes = 32;
final long pageSizeBytes = 256 + 8; // 8 bytes for end-of-page marker
final long numRecordsPerPage = (pageSizeBytes - 8) / recordLengthBytes;
final BytesToBytesMap map = new BytesToBytesMap(taskMemoryManager, 1024, pageSizeBytes);
@@ -622,7 +660,7 @@ public abstract class AbstractBytesToBytesMapSuite {
try {
for (long i = 0; i < numRecordsPerPage * 10; i++) {
final long[] value = new long[]{i};
- map.lookup(value, Platform.LONG_ARRAY_OFFSET, 8).putNewKey(
+ map.lookup(value, Platform.LONG_ARRAY_OFFSET, 8).append(
value,
Platform.LONG_ARRAY_OFFSET,
8,
diff --git a/core/src/test/resources/HistoryServerExpectations/application_list_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/application_list_json_expectation.json
index 5bbb4ceb97..1a13233133 100644
--- a/core/src/test/resources/HistoryServerExpectations/application_list_json_expectation.json
+++ b/core/src/test/resources/HistoryServerExpectations/application_list_json_expectation.json
@@ -2,6 +2,9 @@
"id" : "local-1430917381534",
"name" : "Spark shell",
"attempts" : [ {
+ "startTimeEpoch" : 1430917380893,
+ "endTimeEpoch" : 1430917391398,
+ "lastUpdatedEpoch" : 0,
"startTime" : "2015-05-06T13:03:00.893GMT",
"endTime" : "2015-05-06T13:03:11.398GMT",
"lastUpdated" : "",
@@ -14,6 +17,9 @@
"name" : "Spark shell",
"attempts" : [ {
"attemptId" : "2",
+ "startTimeEpoch" : 1430917380893,
+ "endTimeEpoch" : 1430917380950,
+ "lastUpdatedEpoch" : 0,
"startTime" : "2015-05-06T13:03:00.893GMT",
"endTime" : "2015-05-06T13:03:00.950GMT",
"lastUpdated" : "",
@@ -22,6 +28,9 @@
"completed" : true
}, {
"attemptId" : "1",
+ "startTimeEpoch" : 1430917380880,
+ "endTimeEpoch" : 1430917380890,
+ "lastUpdatedEpoch" : 0,
"startTime" : "2015-05-06T13:03:00.880GMT",
"endTime" : "2015-05-06T13:03:00.890GMT",
"lastUpdated" : "",
@@ -34,6 +43,9 @@
"name" : "Spark shell",
"attempts" : [ {
"attemptId" : "2",
+ "startTimeEpoch" : 1426633910242,
+ "endTimeEpoch" : 1426633945177,
+ "lastUpdatedEpoch" : 0,
"startTime" : "2015-03-17T23:11:50.242GMT",
"endTime" : "2015-03-17T23:12:25.177GMT",
"lastUpdated" : "",
@@ -42,6 +54,9 @@
"completed" : true
}, {
"attemptId" : "1",
+ "startTimeEpoch" : 1426533910242,
+ "endTimeEpoch" : 1426533945177,
+ "lastUpdatedEpoch" : 0,
"startTime" : "2015-03-16T19:25:10.242GMT",
"endTime" : "2015-03-16T19:25:45.177GMT",
"lastUpdated" : "",
@@ -53,6 +68,9 @@
"id" : "local-1425081759269",
"name" : "Spark shell",
"attempts" : [ {
+ "startTimeEpoch" : 1425081758277,
+ "endTimeEpoch" : 1425081766912,
+ "lastUpdatedEpoch" : 0,
"startTime" : "2015-02-28T00:02:38.277GMT",
"endTime" : "2015-02-28T00:02:46.912GMT",
"lastUpdated" : "",
@@ -64,6 +82,9 @@
"id" : "local-1422981780767",
"name" : "Spark shell",
"attempts" : [ {
+ "startTimeEpoch" : 1422981779720,
+ "endTimeEpoch" : 1422981788731,
+ "lastUpdatedEpoch" : 0,
"startTime" : "2015-02-03T16:42:59.720GMT",
"endTime" : "2015-02-03T16:43:08.731GMT",
"lastUpdated" : "",
@@ -75,6 +96,9 @@
"id" : "local-1422981759269",
"name" : "Spark shell",
"attempts" : [ {
+ "startTimeEpoch" : 1422981758277,
+ "endTimeEpoch" : 1422981766912,
+ "lastUpdatedEpoch" : 0,
"startTime" : "2015-02-03T16:42:38.277GMT",
"endTime" : "2015-02-03T16:42:46.912GMT",
"lastUpdated" : "",
diff --git a/core/src/test/resources/HistoryServerExpectations/completed_app_list_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/completed_app_list_json_expectation.json
index 5bbb4ceb97..1a13233133 100644
--- a/core/src/test/resources/HistoryServerExpectations/completed_app_list_json_expectation.json
+++ b/core/src/test/resources/HistoryServerExpectations/completed_app_list_json_expectation.json
@@ -2,6 +2,9 @@
"id" : "local-1430917381534",
"name" : "Spark shell",
"attempts" : [ {
+ "startTimeEpoch" : 1430917380893,
+ "endTimeEpoch" : 1430917391398,
+ "lastUpdatedEpoch" : 0,
"startTime" : "2015-05-06T13:03:00.893GMT",
"endTime" : "2015-05-06T13:03:11.398GMT",
"lastUpdated" : "",
@@ -14,6 +17,9 @@
"name" : "Spark shell",
"attempts" : [ {
"attemptId" : "2",
+ "startTimeEpoch" : 1430917380893,
+ "endTimeEpoch" : 1430917380950,
+ "lastUpdatedEpoch" : 0,
"startTime" : "2015-05-06T13:03:00.893GMT",
"endTime" : "2015-05-06T13:03:00.950GMT",
"lastUpdated" : "",
@@ -22,6 +28,9 @@
"completed" : true
}, {
"attemptId" : "1",
+ "startTimeEpoch" : 1430917380880,
+ "endTimeEpoch" : 1430917380890,
+ "lastUpdatedEpoch" : 0,
"startTime" : "2015-05-06T13:03:00.880GMT",
"endTime" : "2015-05-06T13:03:00.890GMT",
"lastUpdated" : "",
@@ -34,6 +43,9 @@
"name" : "Spark shell",
"attempts" : [ {
"attemptId" : "2",
+ "startTimeEpoch" : 1426633910242,
+ "endTimeEpoch" : 1426633945177,
+ "lastUpdatedEpoch" : 0,
"startTime" : "2015-03-17T23:11:50.242GMT",
"endTime" : "2015-03-17T23:12:25.177GMT",
"lastUpdated" : "",
@@ -42,6 +54,9 @@
"completed" : true
}, {
"attemptId" : "1",
+ "startTimeEpoch" : 1426533910242,
+ "endTimeEpoch" : 1426533945177,
+ "lastUpdatedEpoch" : 0,
"startTime" : "2015-03-16T19:25:10.242GMT",
"endTime" : "2015-03-16T19:25:45.177GMT",
"lastUpdated" : "",
@@ -53,6 +68,9 @@
"id" : "local-1425081759269",
"name" : "Spark shell",
"attempts" : [ {
+ "startTimeEpoch" : 1425081758277,
+ "endTimeEpoch" : 1425081766912,
+ "lastUpdatedEpoch" : 0,
"startTime" : "2015-02-28T00:02:38.277GMT",
"endTime" : "2015-02-28T00:02:46.912GMT",
"lastUpdated" : "",
@@ -64,6 +82,9 @@
"id" : "local-1422981780767",
"name" : "Spark shell",
"attempts" : [ {
+ "startTimeEpoch" : 1422981779720,
+ "endTimeEpoch" : 1422981788731,
+ "lastUpdatedEpoch" : 0,
"startTime" : "2015-02-03T16:42:59.720GMT",
"endTime" : "2015-02-03T16:43:08.731GMT",
"lastUpdated" : "",
@@ -75,6 +96,9 @@
"id" : "local-1422981759269",
"name" : "Spark shell",
"attempts" : [ {
+ "startTimeEpoch" : 1422981758277,
+ "endTimeEpoch" : 1422981766912,
+ "lastUpdatedEpoch" : 0,
"startTime" : "2015-02-03T16:42:38.277GMT",
"endTime" : "2015-02-03T16:42:46.912GMT",
"lastUpdated" : "",
diff --git a/core/src/test/resources/HistoryServerExpectations/executor_list_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/executor_list_json_expectation.json
index 4a88eeee74..efc865919b 100644
--- a/core/src/test/resources/HistoryServerExpectations/executor_list_json_expectation.json
+++ b/core/src/test/resources/HistoryServerExpectations/executor_list_json_expectation.json
@@ -2,8 +2,8 @@
"id" : "<driver>",
"hostPort" : "localhost:57971",
"isActive" : true,
- "rddBlocks" : 8,
- "memoryUsed" : 28000128,
+ "rddBlocks" : 0,
+ "memoryUsed" : 0,
"diskUsed" : 0,
"totalCores" : 0,
"maxTasks" : 0,
diff --git a/core/src/test/resources/HistoryServerExpectations/maxDate2_app_list_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/maxDate2_app_list_json_expectation.json
index 3f80a529a0..eacf04b901 100644
--- a/core/src/test/resources/HistoryServerExpectations/maxDate2_app_list_json_expectation.json
+++ b/core/src/test/resources/HistoryServerExpectations/maxDate2_app_list_json_expectation.json
@@ -2,6 +2,9 @@
"id" : "local-1422981759269",
"name" : "Spark shell",
"attempts" : [ {
+ "startTimeEpoch" : 1422981758277,
+ "endTimeEpoch" : 1422981766912,
+ "lastUpdatedEpoch" : 0,
"startTime" : "2015-02-03T16:42:38.277GMT",
"endTime" : "2015-02-03T16:42:46.912GMT",
"lastUpdated" : "",
diff --git a/core/src/test/resources/HistoryServerExpectations/maxDate_app_list_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/maxDate_app_list_json_expectation.json
index 508bdc17ef..adad25bf17 100644
--- a/core/src/test/resources/HistoryServerExpectations/maxDate_app_list_json_expectation.json
+++ b/core/src/test/resources/HistoryServerExpectations/maxDate_app_list_json_expectation.json
@@ -2,6 +2,9 @@
"id" : "local-1422981780767",
"name" : "Spark shell",
"attempts" : [ {
+ "startTimeEpoch" : 1422981779720,
+ "endTimeEpoch" : 1422981788731,
+ "lastUpdatedEpoch" : 0,
"startTime" : "2015-02-03T16:42:59.720GMT",
"endTime" : "2015-02-03T16:43:08.731GMT",
"lastUpdated" : "",
@@ -13,6 +16,9 @@
"id" : "local-1422981759269",
"name" : "Spark shell",
"attempts" : [ {
+ "startTimeEpoch" : 1422981758277,
+ "endTimeEpoch" : 1422981766912,
+ "lastUpdatedEpoch" : 0,
"startTime" : "2015-02-03T16:42:38.277GMT",
"endTime" : "2015-02-03T16:42:46.912GMT",
"lastUpdated" : "",
diff --git a/core/src/test/resources/HistoryServerExpectations/minDate_app_list_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/minDate_app_list_json_expectation.json
index 5dca7d73de..a658909088 100644
--- a/core/src/test/resources/HistoryServerExpectations/minDate_app_list_json_expectation.json
+++ b/core/src/test/resources/HistoryServerExpectations/minDate_app_list_json_expectation.json
@@ -2,6 +2,9 @@
"id" : "local-1430917381534",
"name" : "Spark shell",
"attempts" : [ {
+ "startTimeEpoch" : 1430917380893,
+ "endTimeEpoch" : 1430917391398,
+ "lastUpdatedEpoch" : 0,
"startTime" : "2015-05-06T13:03:00.893GMT",
"endTime" : "2015-05-06T13:03:11.398GMT",
"lastUpdated" : "",
@@ -14,6 +17,9 @@
"name" : "Spark shell",
"attempts" : [ {
"attemptId" : "2",
+ "startTimeEpoch" : 1430917380893,
+ "endTimeEpoch" : 1430917380950,
+ "lastUpdatedEpoch" : 0,
"startTime" : "2015-05-06T13:03:00.893GMT",
"endTime" : "2015-05-06T13:03:00.950GMT",
"lastUpdated" : "",
@@ -22,6 +28,9 @@
"completed" : true
}, {
"attemptId" : "1",
+ "startTimeEpoch" : 1430917380880,
+ "endTimeEpoch" : 1430917380890,
+ "lastUpdatedEpoch" : 0,
"startTime" : "2015-05-06T13:03:00.880GMT",
"endTime" : "2015-05-06T13:03:00.890GMT",
"lastUpdated" : "",
@@ -34,6 +43,9 @@
"name" : "Spark shell",
"attempts" : [ {
"attemptId" : "2",
+ "startTimeEpoch" : 1426633910242,
+ "endTimeEpoch" : 1426633945177,
+ "lastUpdatedEpoch" : 0,
"startTime" : "2015-03-17T23:11:50.242GMT",
"endTime" : "2015-03-17T23:12:25.177GMT",
"lastUpdated" : "",
@@ -42,6 +54,9 @@
"completed" : true
}, {
"attemptId" : "1",
+ "startTimeEpoch" : 1426533910242,
+ "endTimeEpoch" : 1426533945177,
+ "lastUpdatedEpoch" : 0,
"startTime" : "2015-03-16T19:25:10.242GMT",
"endTime" : "2015-03-16T19:25:45.177GMT",
"lastUpdated" : "",
@@ -54,6 +69,9 @@
"name": "Spark shell",
"attempts": [
{
+ "startTimeEpoch" : 1425081758277,
+ "endTimeEpoch" : 1425081766912,
+ "lastUpdatedEpoch" : 0,
"startTime": "2015-02-28T00:02:38.277GMT",
"endTime": "2015-02-28T00:02:46.912GMT",
"lastUpdated" : "",
diff --git a/core/src/test/resources/HistoryServerExpectations/one_app_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/one_app_json_expectation.json
index cca32c7910..0217facad9 100644
--- a/core/src/test/resources/HistoryServerExpectations/one_app_json_expectation.json
+++ b/core/src/test/resources/HistoryServerExpectations/one_app_json_expectation.json
@@ -2,6 +2,9 @@
"id" : "local-1422981780767",
"name" : "Spark shell",
"attempts" : [ {
+ "startTimeEpoch" : 1422981779720,
+ "endTimeEpoch" : 1422981788731,
+ "lastUpdatedEpoch" : 0,
"startTime" : "2015-02-03T16:42:59.720GMT",
"endTime" : "2015-02-03T16:43:08.731GMT",
"lastUpdated" : "",
diff --git a/core/src/test/resources/HistoryServerExpectations/one_app_multi_attempt_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/one_app_multi_attempt_json_expectation.json
index 1ea1779e83..b20a26648e 100644
--- a/core/src/test/resources/HistoryServerExpectations/one_app_multi_attempt_json_expectation.json
+++ b/core/src/test/resources/HistoryServerExpectations/one_app_multi_attempt_json_expectation.json
@@ -3,6 +3,9 @@
"name" : "Spark shell",
"attempts" : [ {
"attemptId" : "2",
+ "startTimeEpoch" : 1426633910242,
+ "endTimeEpoch" : 1426633945177,
+ "lastUpdatedEpoch" : 0,
"startTime" : "2015-03-17T23:11:50.242GMT",
"endTime" : "2015-03-17T23:12:25.177GMT",
"lastUpdated" : "",
@@ -11,6 +14,9 @@
"completed" : true
}, {
"attemptId" : "1",
+ "startTimeEpoch" : 1426533910242,
+ "endTimeEpoch" : 1426533945177,
+ "lastUpdatedEpoch" : 0,
"startTime" : "2015-03-16T19:25:10.242GMT",
"endTime" : "2015-03-16T19:25:45.177GMT",
"lastUpdated" : "",
diff --git a/core/src/test/resources/HistoryServerExpectations/rdd_list_storage_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/rdd_list_storage_json_expectation.json
index f79a31022d..8878e547a7 100644
--- a/core/src/test/resources/HistoryServerExpectations/rdd_list_storage_json_expectation.json
+++ b/core/src/test/resources/HistoryServerExpectations/rdd_list_storage_json_expectation.json
@@ -1,9 +1 @@
-[ {
- "id" : 0,
- "name" : "0",
- "numPartitions" : 8,
- "numCachedPartitions" : 8,
- "storageLevel" : "Memory Deserialized 1x Replicated",
- "memoryUsed" : 28000128,
- "diskUsed" : 0
-} ] \ No newline at end of file
+[ ] \ No newline at end of file
diff --git a/core/src/test/resources/log4j.properties b/core/src/test/resources/log4j.properties
index a54d27de91..fb9d9851cb 100644
--- a/core/src/test/resources/log4j.properties
+++ b/core/src/test/resources/log4j.properties
@@ -33,5 +33,4 @@ log4j.appender.console.layout=org.apache.log4j.PatternLayout
log4j.appender.console.layout.ConversionPattern=%t: %m%n
# Ignore messages below warning level from Jetty, because it's a bit verbose
-log4j.logger.org.spark-project.jetty=WARN
-org.spark-project.jetty.LEVEL=WARN
+log4j.logger.org.spark_project.jetty=WARN
diff --git a/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala b/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala
index ec192a8543..37879d11ca 100644
--- a/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala
+++ b/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala
@@ -17,6 +17,7 @@
package org.apache.spark
+import java.util.Properties
import java.util.concurrent.Semaphore
import javax.annotation.concurrent.GuardedBy
@@ -292,7 +293,7 @@ class AccumulatorSuite extends SparkFunSuite with Matchers with LocalSparkContex
dummyTask, mutable.HashMap(), mutable.HashMap(), serInstance)
// Now we're on the executors.
// Deserialize the task and assert that its accumulators are zero'ed out.
- val (_, _, taskBytes) = Task.deserializeWithDependencies(taskSer)
+ val (_, _, _, taskBytes) = Task.deserializeWithDependencies(taskSer)
val taskDeser = serInstance.deserialize[DummyTask](
taskBytes, Thread.currentThread.getContextClassLoader)
// Assert that executors see only zeros
@@ -403,6 +404,6 @@ private class SaveInfoListener extends SparkListener {
private[spark] class DummyTask(
val internalAccums: Seq[Accumulator[_]],
val externalAccums: Seq[Accumulator[_]])
- extends Task[Int](0, 0, 0, internalAccums) {
+ extends Task[Int](0, 0, 0, internalAccums, new Properties) {
override def runTask(c: TaskContext): Int = 1
}
diff --git a/core/src/test/scala/org/apache/spark/DistributedSuite.scala b/core/src/test/scala/org/apache/spark/DistributedSuite.scala
index 3dded4d486..2110d3d770 100644
--- a/core/src/test/scala/org/apache/spark/DistributedSuite.scala
+++ b/core/src/test/scala/org/apache/spark/DistributedSuite.scala
@@ -198,8 +198,8 @@ class DistributedSuite extends SparkFunSuite with Matchers with LocalSparkContex
blockManager.master.getLocations(blockId).foreach { cmId =>
val bytes = blockTransfer.fetchBlockSync(cmId.host, cmId.port, cmId.executorId,
blockId.toString)
- val deserialized = serializerManager.dataDeserialize[Int](blockId,
- new ChunkedByteBuffer(bytes.nioByteBuffer())).toList
+ val deserialized = serializerManager.dataDeserializeStream[Int](blockId,
+ new ChunkedByteBuffer(bytes.nioByteBuffer()).toInputStream()).toList
assert(deserialized === (1 to 100).toList)
}
}
@@ -320,7 +320,7 @@ class DistributedSuite extends SparkFunSuite with Matchers with LocalSparkContex
Thread.sleep(200)
}
} catch {
- case _: Throwable => { Thread.sleep(10) }
+ case _: Throwable => Thread.sleep(10)
// Do nothing. We might see exceptions because block manager
// is racing this thread to remove entries from the driver.
}
diff --git a/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala
index 80a1de6065..ee6b991461 100644
--- a/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala
@@ -928,8 +928,8 @@ private object ExecutorAllocationManagerSuite extends PrivateMethodTester {
numTasks: Int,
taskLocalityPreferences: Seq[Seq[TaskLocation]] = Seq.empty
): StageInfo = {
- new StageInfo(
- stageId, 0, "name", numTasks, Seq.empty, Seq.empty, "no details", taskLocalityPreferences)
+ new StageInfo(stageId, 0, "name", numTasks, Seq.empty, Seq.empty, "no details",
+ Seq.empty, taskLocalityPreferences)
}
private def createTaskInfo(taskId: Int, taskIndex: Int, executorId: String): TaskInfo = {
diff --git a/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala b/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala
index 3777d77f8f..713d5e58b4 100644
--- a/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala
+++ b/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala
@@ -174,9 +174,9 @@ class HeartbeatReceiverSuite
val dummyExecutorEndpoint2 = new FakeExecutorEndpoint(rpcEnv)
val dummyExecutorEndpointRef1 = rpcEnv.setupEndpoint("fake-executor-1", dummyExecutorEndpoint1)
val dummyExecutorEndpointRef2 = rpcEnv.setupEndpoint("fake-executor-2", dummyExecutorEndpoint2)
- fakeSchedulerBackend.driverEndpoint.askWithRetry[RegisterExecutorResponse](
+ fakeSchedulerBackend.driverEndpoint.askWithRetry[Boolean](
RegisterExecutor(executorId1, dummyExecutorEndpointRef1, 0, Map.empty))
- fakeSchedulerBackend.driverEndpoint.askWithRetry[RegisterExecutorResponse](
+ fakeSchedulerBackend.driverEndpoint.askWithRetry[Boolean](
RegisterExecutor(executorId2, dummyExecutorEndpointRef2, 0, Map.empty))
heartbeatReceiverRef.askWithRetry[Boolean](TaskSchedulerIsSet)
addExecutorAndVerify(executorId1)
@@ -255,7 +255,12 @@ class HeartbeatReceiverSuite
/**
* Dummy RPC endpoint to simulate executors.
*/
-private class FakeExecutorEndpoint(override val rpcEnv: RpcEnv) extends RpcEndpoint
+private class FakeExecutorEndpoint(override val rpcEnv: RpcEnv) extends RpcEndpoint {
+
+ override def receive: PartialFunction[Any, Unit] = {
+ case _ =>
+ }
+}
/**
* Dummy scheduler backend to simulate executor allocation requests to the cluster manager.
diff --git a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
index 6ffa1c8ac1..cd7d2e1570 100644
--- a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
+++ b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
@@ -17,6 +17,7 @@
package org.apache.spark
+import java.util.Properties
import java.util.concurrent.{Callable, CyclicBarrier, Executors, ExecutorService}
import org.scalatest.Matchers
@@ -335,16 +336,16 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC
// first attempt -- it's successful
val writer1 = manager.getWriter[Int, Int](shuffleHandle, 0,
- new TaskContextImpl(0, 0, 0L, 0, taskMemoryManager, metricsSystem,
- InternalAccumulator.create(sc)))
+ new TaskContextImpl(0, 0, 0L, 0, taskMemoryManager, new Properties, metricsSystem,
+ InternalAccumulator.createAll(sc)))
val data1 = (1 to 10).map { x => x -> x}
// second attempt -- also successful. We'll write out different data,
// just to simulate the fact that the records may get written differently
// depending on what gets spilled, what gets combined, etc.
val writer2 = manager.getWriter[Int, Int](shuffleHandle, 0,
- new TaskContextImpl(0, 0, 1L, 0, taskMemoryManager, metricsSystem,
- InternalAccumulator.create(sc)))
+ new TaskContextImpl(0, 0, 1L, 0, taskMemoryManager, new Properties, metricsSystem,
+ InternalAccumulator.createAll(sc)))
val data2 = (11 to 20).map { x => x -> x}
// interleave writes of both attempts -- we want to test that both attempts can occur
@@ -372,8 +373,8 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC
}
val reader = manager.getReader[Int, Int](shuffleHandle, 0, 1,
- new TaskContextImpl(1, 0, 2L, 0, taskMemoryManager, metricsSystem,
- InternalAccumulator.create(sc)))
+ new TaskContextImpl(1, 0, 2L, 0, taskMemoryManager, new Properties, metricsSystem,
+ InternalAccumulator.createAll(sc)))
val readData = reader.read().toIndexedSeq
assert(readData === data1.toIndexedSeq || readData === data2.toIndexedSeq)
diff --git a/core/src/test/scala/org/apache/spark/Smuggle.scala b/core/src/test/scala/org/apache/spark/Smuggle.scala
index 9f0a1b4c25..9d9217ea1b 100644
--- a/core/src/test/scala/org/apache/spark/Smuggle.scala
+++ b/core/src/test/scala/org/apache/spark/Smuggle.scala
@@ -24,16 +24,16 @@ import scala.collection.mutable
import scala.language.implicitConversions
/**
- * Utility wrapper to "smuggle" objects into tasks while bypassing serialization.
- * This is intended for testing purposes, primarily to make locks, semaphores, and
- * other constructs that would not survive serialization available from within tasks.
- * A Smuggle reference is itself serializable, but after being serialized and
- * deserialized, it still refers to the same underlying "smuggled" object, as long
- * as it was deserialized within the same JVM. This can be useful for tests that
- * depend on the timing of task completion to be deterministic, since one can "smuggle"
- * a lock or semaphore into the task, and then the task can block until the test gives
- * the go-ahead to proceed via the lock.
- */
+ * Utility wrapper to "smuggle" objects into tasks while bypassing serialization.
+ * This is intended for testing purposes, primarily to make locks, semaphores, and
+ * other constructs that would not survive serialization available from within tasks.
+ * A Smuggle reference is itself serializable, but after being serialized and
+ * deserialized, it still refers to the same underlying "smuggled" object, as long
+ * as it was deserialized within the same JVM. This can be useful for tests that
+ * depend on the timing of task completion to be deterministic, since one can "smuggle"
+ * a lock or semaphore into the task, and then the task can block until the test gives
+ * the go-ahead to proceed via the lock.
+ */
class Smuggle[T] private(val key: Symbol) extends Serializable {
def smuggledObject: T = Smuggle.get(key)
}
@@ -41,13 +41,13 @@ class Smuggle[T] private(val key: Symbol) extends Serializable {
object Smuggle {
/**
- * Wraps the specified object to be smuggled into a serialized task without
- * being serialized itself.
- *
- * @param smuggledObject
- * @tparam T
- * @return Smuggle wrapper around smuggledObject.
- */
+ * Wraps the specified object to be smuggled into a serialized task without
+ * being serialized itself.
+ *
+ * @param smuggledObject
+ * @tparam T
+ * @return Smuggle wrapper around smuggledObject.
+ */
def apply[T](smuggledObject: T): Smuggle[T] = {
val key = Symbol(UUID.randomUUID().toString)
lock.writeLock().lock()
@@ -72,12 +72,12 @@ object Smuggle {
}
/**
- * Implicit conversion of a Smuggle wrapper to the object being smuggled.
- *
- * @param smuggle the wrapper to unpack.
- * @tparam T
- * @return the smuggled object represented by the wrapper.
- */
+ * Implicit conversion of a Smuggle wrapper to the object being smuggled.
+ *
+ * @param smuggle the wrapper to unpack.
+ * @tparam T
+ * @return the smuggled object represented by the wrapper.
+ */
implicit def unpackSmuggledObject[T](smuggle : Smuggle[T]): T = smuggle.smuggledObject
}
diff --git a/core/src/test/scala/org/apache/spark/SparkContextInfoSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextInfoSuite.scala
index 3706455c3f..8feb3dee05 100644
--- a/core/src/test/scala/org/apache/spark/SparkContextInfoSuite.scala
+++ b/core/src/test/scala/org/apache/spark/SparkContextInfoSuite.scala
@@ -82,20 +82,18 @@ package object testPackage extends Assertions {
val curCallSite = sc.getCallSite().shortForm // note: 2 lines after definition of "rdd"
val rddCreationLine = rddCreationSite match {
- case CALL_SITE_REGEX(func, file, line) => {
+ case CALL_SITE_REGEX(func, file, line) =>
assert(func === "makeRDD")
assert(file === "SparkContextInfoSuite.scala")
line.toInt
- }
case _ => fail("Did not match expected call site format")
}
curCallSite match {
- case CALL_SITE_REGEX(func, file, line) => {
+ case CALL_SITE_REGEX(func, file, line) =>
assert(func === "getCallSite") // this is correct because we called it from outside of Spark
assert(file === "SparkContextInfoSuite.scala")
assert(line.toInt === rddCreationLine.toInt + 2)
- }
case _ => fail("Did not match expected call site format")
}
}
diff --git a/core/src/test/scala/org/apache/spark/UnpersistSuite.scala b/core/src/test/scala/org/apache/spark/UnpersistSuite.scala
index f7a13ab399..09e21646ee 100644
--- a/core/src/test/scala/org/apache/spark/UnpersistSuite.scala
+++ b/core/src/test/scala/org/apache/spark/UnpersistSuite.scala
@@ -35,7 +35,7 @@ class UnpersistSuite extends SparkFunSuite with LocalSparkContext {
Thread.sleep(200)
}
} catch {
- case _: Throwable => { Thread.sleep(10) }
+ case _: Throwable => Thread.sleep(10)
// Do nothing. We might see exceptions because block manager
// is racing this thread to remove entries from the driver.
}
diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala
index 96cb4fd0eb..2718976992 100644
--- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala
+++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala
@@ -199,21 +199,21 @@ class SparkSubmitSuite
val (childArgs, classpath, sysProps, mainClass) = prepareSubmitEnvironment(appArgs)
val childArgsStr = childArgs.mkString(" ")
childArgsStr should include ("--class org.SomeClass")
- childArgsStr should include ("--executor-memory 5g")
- childArgsStr should include ("--driver-memory 4g")
- childArgsStr should include ("--executor-cores 5")
childArgsStr should include ("--arg arg1 --arg arg2")
- childArgsStr should include ("--queue thequeue")
childArgsStr should include regex ("--jar .*thejar.jar")
- childArgsStr should include regex ("--addJars .*one.jar,.*two.jar,.*three.jar")
- childArgsStr should include regex ("--files .*file1.txt,.*file2.txt")
- childArgsStr should include regex ("--archives .*archive1.txt,.*archive2.txt")
mainClass should be ("org.apache.spark.deploy.yarn.Client")
classpath should have length (0)
+
+ sysProps("spark.executor.memory") should be ("5g")
+ sysProps("spark.driver.memory") should be ("4g")
+ sysProps("spark.executor.cores") should be ("5")
+ sysProps("spark.yarn.queue") should be ("thequeue")
+ sysProps("spark.yarn.dist.jars") should include regex (".*one.jar,.*two.jar,.*three.jar")
+ sysProps("spark.yarn.dist.files") should include regex (".*file1.txt,.*file2.txt")
+ sysProps("spark.yarn.dist.archives") should include regex (".*archive1.txt,.*archive2.txt")
sysProps("spark.app.name") should be ("beauty")
sysProps("spark.ui.enabled") should be ("false")
sysProps("SPARK_SUBMIT") should be ("true")
- sysProps.keys should not contain ("spark.jars")
}
test("handles YARN client mode") {
@@ -249,7 +249,8 @@ class SparkSubmitSuite
sysProps("spark.executor.instances") should be ("6")
sysProps("spark.yarn.dist.files") should include regex (".*file1.txt,.*file2.txt")
sysProps("spark.yarn.dist.archives") should include regex (".*archive1.txt,.*archive2.txt")
- sysProps("spark.jars") should include regex (".*one.jar,.*two.jar,.*three.jar,.*thejar.jar")
+ sysProps("spark.yarn.dist.jars") should include
+ regex (".*one.jar,.*two.jar,.*three.jar,.*thejar.jar")
sysProps("SPARK_SUBMIT") should be ("true")
sysProps("spark.ui.enabled") should be ("false")
}
diff --git a/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala b/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala
index d2e24912b5..3d39bd4a74 100644
--- a/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala
+++ b/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala
@@ -561,7 +561,7 @@ class StandaloneDynamicAllocationSuite
when(endpointRef.address).thenReturn(mockAddress)
val message = RegisterExecutor(id, endpointRef, 10, Map.empty)
val backend = sc.schedulerBackend.asInstanceOf[CoarseGrainedSchedulerBackend]
- backend.driverEndpoint.askWithRetry[CoarseGrainedClusterMessage](message)
+ backend.driverEndpoint.askWithRetry[Boolean](message)
}
}
diff --git a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala
index 5822261d8d..2a013aca7b 100644
--- a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala
@@ -140,8 +140,9 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers
"stage task list from multi-attempt app json(2)" ->
"applications/local-1426533911241/2/stages/0/0/taskList",
- "rdd list storage json" -> "applications/local-1422981780767/storage/rdd",
- "one rdd storage json" -> "applications/local-1422981780767/storage/rdd/0"
+ "rdd list storage json" -> "applications/local-1422981780767/storage/rdd"
+ // Todo: enable this test when logging the event of onBlockUpdated. See: SPARK-13845
+ // "one rdd storage json" -> "applications/local-1422981780767/storage/rdd/0"
)
// run a bunch of characterization tests -- just verify the behavior is the same as what is saved
@@ -161,7 +162,9 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers
val json = if (jsonOrg.indexOf("lastUpdated") >= 0) {
val subStrings = jsonOrg.split(",")
for (i <- subStrings.indices) {
- if (subStrings(i).indexOf("lastUpdated") >= 0) {
+ if (subStrings(i).indexOf("lastUpdatedEpoch") >= 0) {
+ subStrings(i) = subStrings(i).replaceAll("(\\d+)", "0")
+ } else if (subStrings(i).indexOf("lastUpdated") >= 0) {
subStrings(i) = "\"lastUpdated\":\"\""
}
}
diff --git a/core/src/test/scala/org/apache/spark/executor/TaskMetricsSuite.scala b/core/src/test/scala/org/apache/spark/executor/TaskMetricsSuite.scala
index 088b05403c..d91f50f18f 100644
--- a/core/src/test/scala/org/apache/spark/executor/TaskMetricsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/executor/TaskMetricsSuite.scala
@@ -285,8 +285,8 @@ class TaskMetricsSuite extends SparkFunSuite {
// set and increment values
in.setBytesRead(1L)
in.setBytesRead(2L)
- in.incRecordsReadInternal(1L)
- in.incRecordsReadInternal(2L)
+ in.incRecordsRead(1L)
+ in.incRecordsRead(2L)
in.setReadMethod(DataReadMethod.Disk)
// assert new values exist
assertValEquals(_.bytesRead, BYTES_READ, 2L)
diff --git a/core/src/test/scala/org/apache/spark/internal/config/ConfigEntrySuite.scala b/core/src/test/scala/org/apache/spark/internal/config/ConfigEntrySuite.scala
index 0644148eae..337fd7e85e 100644
--- a/core/src/test/scala/org/apache/spark/internal/config/ConfigEntrySuite.scala
+++ b/core/src/test/scala/org/apache/spark/internal/config/ConfigEntrySuite.scala
@@ -26,7 +26,7 @@ class ConfigEntrySuite extends SparkFunSuite {
test("conf entry: int") {
val conf = new SparkConf()
- val iConf = ConfigBuilder("spark.int").intConf.withDefault(1)
+ val iConf = ConfigBuilder("spark.int").intConf.createWithDefault(1)
assert(conf.get(iConf) === 1)
conf.set(iConf, 2)
assert(conf.get(iConf) === 2)
@@ -34,21 +34,21 @@ class ConfigEntrySuite extends SparkFunSuite {
test("conf entry: long") {
val conf = new SparkConf()
- val lConf = ConfigBuilder("spark.long").longConf.withDefault(0L)
+ val lConf = ConfigBuilder("spark.long").longConf.createWithDefault(0L)
conf.set(lConf, 1234L)
assert(conf.get(lConf) === 1234L)
}
test("conf entry: double") {
val conf = new SparkConf()
- val dConf = ConfigBuilder("spark.double").doubleConf.withDefault(0.0)
+ val dConf = ConfigBuilder("spark.double").doubleConf.createWithDefault(0.0)
conf.set(dConf, 20.0)
assert(conf.get(dConf) === 20.0)
}
test("conf entry: boolean") {
val conf = new SparkConf()
- val bConf = ConfigBuilder("spark.boolean").booleanConf.withDefault(false)
+ val bConf = ConfigBuilder("spark.boolean").booleanConf.createWithDefault(false)
assert(!conf.get(bConf))
conf.set(bConf, true)
assert(conf.get(bConf))
@@ -56,7 +56,7 @@ class ConfigEntrySuite extends SparkFunSuite {
test("conf entry: optional") {
val conf = new SparkConf()
- val optionalConf = ConfigBuilder("spark.optional").intConf.optional
+ val optionalConf = ConfigBuilder("spark.optional").intConf.createOptional
assert(conf.get(optionalConf) === None)
conf.set(optionalConf, 1)
assert(conf.get(optionalConf) === Some(1))
@@ -64,7 +64,7 @@ class ConfigEntrySuite extends SparkFunSuite {
test("conf entry: fallback") {
val conf = new SparkConf()
- val parentConf = ConfigBuilder("spark.int").intConf.withDefault(1)
+ val parentConf = ConfigBuilder("spark.int").intConf.createWithDefault(1)
val confWithFallback = ConfigBuilder("spark.fallback").fallbackConf(parentConf)
assert(conf.get(confWithFallback) === 1)
conf.set(confWithFallback, 2)
@@ -74,7 +74,7 @@ class ConfigEntrySuite extends SparkFunSuite {
test("conf entry: time") {
val conf = new SparkConf()
- val time = ConfigBuilder("spark.time").timeConf(TimeUnit.SECONDS).withDefaultString("1h")
+ val time = ConfigBuilder("spark.time").timeConf(TimeUnit.SECONDS).createWithDefaultString("1h")
assert(conf.get(time) === 3600L)
conf.set(time.key, "1m")
assert(conf.get(time) === 60L)
@@ -82,7 +82,7 @@ class ConfigEntrySuite extends SparkFunSuite {
test("conf entry: bytes") {
val conf = new SparkConf()
- val bytes = ConfigBuilder("spark.bytes").bytesConf(ByteUnit.KiB).withDefaultString("1m")
+ val bytes = ConfigBuilder("spark.bytes").bytesConf(ByteUnit.KiB).createWithDefaultString("1m")
assert(conf.get(bytes) === 1024L)
conf.set(bytes.key, "1k")
assert(conf.get(bytes) === 1L)
@@ -90,7 +90,7 @@ class ConfigEntrySuite extends SparkFunSuite {
test("conf entry: string seq") {
val conf = new SparkConf()
- val seq = ConfigBuilder("spark.seq").stringConf.toSequence.withDefault(Seq())
+ val seq = ConfigBuilder("spark.seq").stringConf.toSequence.createWithDefault(Seq())
conf.set(seq.key, "1,,2, 3 , , 4")
assert(conf.get(seq) === Seq("1", "2", "3", "4"))
conf.set(seq, Seq("1", "2"))
@@ -99,7 +99,7 @@ class ConfigEntrySuite extends SparkFunSuite {
test("conf entry: int seq") {
val conf = new SparkConf()
- val seq = ConfigBuilder("spark.seq").intConf.toSequence.withDefault(Seq())
+ val seq = ConfigBuilder("spark.seq").intConf.toSequence.createWithDefault(Seq())
conf.set(seq.key, "1,,2, 3 , , 4")
assert(conf.get(seq) === Seq(1, 2, 3, 4))
conf.set(seq, Seq(1, 2))
@@ -111,7 +111,7 @@ class ConfigEntrySuite extends SparkFunSuite {
val transformationConf = ConfigBuilder("spark.transformation")
.stringConf
.transform(_.toLowerCase())
- .withDefault("FOO")
+ .createWithDefault("FOO")
assert(conf.get(transformationConf) === "foo")
conf.set(transformationConf, "BAR")
@@ -123,7 +123,7 @@ class ConfigEntrySuite extends SparkFunSuite {
val enum = ConfigBuilder("spark.enum")
.stringConf
.checkValues(Set("a", "b", "c"))
- .withDefault("a")
+ .createWithDefault("a")
assert(conf.get(enum) === "a")
conf.set(enum, "b")
@@ -138,7 +138,7 @@ class ConfigEntrySuite extends SparkFunSuite {
test("conf entry: conversion error") {
val conf = new SparkConf()
- val conversionTest = ConfigBuilder("spark.conversionTest").doubleConf.optional
+ val conversionTest = ConfigBuilder("spark.conversionTest").doubleConf.createOptional
conf.set(conversionTest.key, "abc")
val conversionError = intercept[IllegalArgumentException] {
conf.get(conversionTest)
@@ -148,7 +148,7 @@ class ConfigEntrySuite extends SparkFunSuite {
test("default value handling is null-safe") {
val conf = new SparkConf()
- val stringConf = ConfigBuilder("spark.string").stringConf.withDefault(null)
+ val stringConf = ConfigBuilder("spark.string").stringConf.createWithDefault(null)
assert(conf.get(stringConf) === null)
}
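
A minimal sketch of the renamed builder calls these tests exercise (withDefault becomes createWithDefault, optional becomes createOptional, withDefaultString becomes createWithDefaultString); the entry keys below are made-up examples, not real Spark configuration keys.

    // Sketch only: ConfigBuilder is Spark-internal API, shown here as in the tests above.
    import org.apache.spark.internal.config.ConfigBuilder

    val maxRetries = ConfigBuilder("spark.example.maxRetries")       // hypothetical key
      .intConf
      .createWithDefault(3)

    val checkpointDir = ConfigBuilder("spark.example.checkpointDir") // hypothetical key
      .stringConf
      .createOptional
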
diff --git a/core/src/test/scala/org/apache/spark/io/ChunkedByteBufferSuite.scala b/core/src/test/scala/org/apache/spark/io/ChunkedByteBufferSuite.scala
index aab70e7431..f205d4f0d6 100644
--- a/core/src/test/scala/org/apache/spark/io/ChunkedByteBufferSuite.scala
+++ b/core/src/test/scala/org/apache/spark/io/ChunkedByteBufferSuite.scala
@@ -52,7 +52,7 @@ class ChunkedByteBufferSuite extends SparkFunSuite {
test("copy() does not affect original buffer's position") {
val chunkedByteBuffer = new ChunkedByteBuffer(Array(ByteBuffer.allocate(8)))
- chunkedByteBuffer.copy()
+ chunkedByteBuffer.copy(ByteBuffer.allocate)
assert(chunkedByteBuffer.getChunks().head.position() === 0)
}
diff --git a/core/src/test/scala/org/apache/spark/memory/MemoryManagerSuite.scala b/core/src/test/scala/org/apache/spark/memory/MemoryManagerSuite.scala
index aaca653c58..99d5b496bc 100644
--- a/core/src/test/scala/org/apache/spark/memory/MemoryManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/memory/MemoryManagerSuite.scala
@@ -71,24 +71,25 @@ private[memory] trait MemoryManagerSuite extends SparkFunSuite with BeforeAndAft
*/
protected def makeMemoryStore(mm: MemoryManager): MemoryStore = {
val ms = mock(classOf[MemoryStore], RETURNS_SMART_NULLS)
- when(ms.evictBlocksToFreeSpace(any(), anyLong())).thenAnswer(evictBlocksToFreeSpaceAnswer(mm))
+ when(ms.evictBlocksToFreeSpace(any(), anyLong(), any()))
+ .thenAnswer(evictBlocksToFreeSpaceAnswer(mm))
mm.setMemoryStore(ms)
ms
}
/**
- * Simulate the part of [[MemoryStore.evictBlocksToFreeSpace]] that releases storage memory.
- *
- * This is a significant simplification of the real method, which actually drops existing
- * blocks based on the size of each block. Instead, here we simply release as many bytes
- * as needed to ensure the requested amount of free space. This allows us to set up the
- * test without relying on the [[org.apache.spark.storage.BlockManager]], which brings in
- * many other dependencies.
- *
- * Every call to this method will set a global variable, [[evictBlocksToFreeSpaceCalled]], that
- * records the number of bytes this is called with. This variable is expected to be cleared
- * by the test code later through [[assertEvictBlocksToFreeSpaceCalled]].
- */
+ * Simulate the part of [[MemoryStore.evictBlocksToFreeSpace]] that releases storage memory.
+ *
+ * This is a significant simplification of the real method, which actually drops existing
+ * blocks based on the size of each block. Instead, here we simply release as many bytes
+ * as needed to ensure the requested amount of free space. This allows us to set up the
+ * test without relying on the [[org.apache.spark.storage.BlockManager]], which brings in
+ * many other dependencies.
+ *
+ * Every call to this method will set a global variable, [[evictBlocksToFreeSpaceCalled]], that
+ * records the number of bytes this is called with. This variable is expected to be cleared
+ * by the test code later through [[assertEvictBlocksToFreeSpaceCalled]].
+ */
private def evictBlocksToFreeSpaceAnswer(mm: MemoryManager): Answer[Long] = {
new Answer[Long] {
override def answer(invocation: InvocationOnMock): Long = {
diff --git a/core/src/test/scala/org/apache/spark/memory/MemoryTestingUtils.scala b/core/src/test/scala/org/apache/spark/memory/MemoryTestingUtils.scala
index 2b5e4b80e9..362cd861cc 100644
--- a/core/src/test/scala/org/apache/spark/memory/MemoryTestingUtils.scala
+++ b/core/src/test/scala/org/apache/spark/memory/MemoryTestingUtils.scala
@@ -17,6 +17,8 @@
package org.apache.spark.memory
+import java.util.Properties
+
import org.apache.spark.{SparkEnv, TaskContext, TaskContextImpl}
/**
@@ -31,6 +33,7 @@ object MemoryTestingUtils {
taskAttemptId = 0,
attemptNumber = 0,
taskMemoryManager = taskMemoryManager,
+ localProperties = new Properties,
metricsSystem = env.metricsSystem)
}
}
diff --git a/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala b/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala
index 6da18cfd49..ed15e77ff1 100644
--- a/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala
+++ b/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala
@@ -108,11 +108,11 @@ class NettyBlockTransferSecuritySuite extends SparkFunSuite with MockitoSugar wi
when(blockManager.getBlockData(blockId)).thenReturn(blockBuffer)
val securityManager0 = new SecurityManager(conf0)
- val exec0 = new NettyBlockTransferService(conf0, securityManager0, numCores = 1)
+ val exec0 = new NettyBlockTransferService(conf0, securityManager0, "localhost", numCores = 1)
exec0.init(blockManager)
val securityManager1 = new SecurityManager(conf1)
- val exec1 = new NettyBlockTransferService(conf1, securityManager1, numCores = 1)
+ val exec1 = new NettyBlockTransferService(conf1, securityManager1, "localhost", numCores = 1)
exec1.init(blockManager)
val result = fetchBlock(exec0, exec1, "1", blockId) match {
diff --git a/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferServiceSuite.scala b/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferServiceSuite.scala
index cc1a9e0287..f3c156e4f7 100644
--- a/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferServiceSuite.scala
+++ b/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferServiceSuite.scala
@@ -80,7 +80,7 @@ class NettyBlockTransferServiceSuite
.set("spark.blockManager.port", port.toString)
val securityManager = new SecurityManager(conf)
val blockDataManager = mock(classOf[BlockDataManager])
- val service = new NettyBlockTransferService(conf, securityManager, numCores = 1)
+ val service = new NettyBlockTransferService(conf, securityManager, "localhost", numCores = 1)
service.init(blockDataManager)
service
}
diff --git a/core/src/test/scala/org/apache/spark/partial/SumEvaluatorSuite.scala b/core/src/test/scala/org/apache/spark/partial/SumEvaluatorSuite.scala
new file mode 100644
index 0000000000..a79f5b4d74
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/partial/SumEvaluatorSuite.scala
@@ -0,0 +1,107 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.partial
+
+import org.apache.spark._
+import org.apache.spark.util.StatCounter
+
+class SumEvaluatorSuite extends SparkFunSuite with SharedSparkContext {
+
+ test("correct handling of count 1") {
+
+ // setup
+ val counter = new StatCounter(List(2.0))
+ // count of 10 because it's larger than 1,
+ // and 0.95 because that's the default
+ val evaluator = new SumEvaluator(10, 0.95)
+ // arbitrarily assign id 1
+ evaluator.merge(1, counter)
+
+ // execute
+ val res = evaluator.currentResult()
+ // 38.0 - 7.1E-15 because that's how the maths shakes out
+ val targetMean = 38.0 - 7.1E-15
+
+ // Sanity check that equality works on BoundedDouble
+ assert(new BoundedDouble(2.0, 0.95, 1.1, 1.2) == new BoundedDouble(2.0, 0.95, 1.1, 1.2))
+
+ // actual test
+ assert(res ==
+ new BoundedDouble(targetMean, 0.950, Double.NegativeInfinity, Double.PositiveInfinity))
+ }
+
+ test("correct handling of count 0") {
+
+ // setup
+ val counter = new StatCounter(List())
+ // count of 10 because it's larger than 0,
+ // and 0.95 because that's the default
+ val evaluator = new SumEvaluator(10, 0.95)
+ // arbitrarily assign id 1
+ evaluator.merge(1, counter)
+
+ // execute
+ val res = evaluator.currentResult()
+ // assert
+ assert(res == new BoundedDouble(0, 0.0, Double.NegativeInfinity, Double.PositiveInfinity))
+ }
+
+ test("correct handling of NaN") {
+
+ // setup
+ val counter = new StatCounter(List(1, Double.NaN, 2))
+ // count of 10 because it's larger than 0,
+ // and 0.95 because that's the default
+ val evaluator = new SumEvaluator(10, 0.95)
+ // arbitrarily assign id 1
+ evaluator.merge(1, counter)
+
+ // execute
+ val res = evaluator.currentResult()
+ // assert - note semantics of == in face of NaN
+ assert(res.mean.isNaN)
+ assert(res.confidence == 0.95)
+ assert(res.low == Double.NegativeInfinity)
+ assert(res.high == Double.PositiveInfinity)
+ }
+
+ test("correct handling of > 1 values") {
+
+ // setup
+ val counter = new StatCounter(List(1, 3, 2))
+ // count of 10 because it's larger than 0,
+ // and 0.95 because that's the default
+ val evaluator = new SumEvaluator(10, 0.95)
+ // arbitrarily assign id 1
+ evaluator.merge(1, counter)
+
+ // execute
+ val res = evaluator.currentResult()
+
+ // These values are simply how the maths shakes out
+ val targetMean = 78.0
+ val targetLow = -117.617 + 2.732357258139473E-5
+ val targetHigh = 273.617 - 2.7323572624027292E-5
+ val target = new BoundedDouble(targetMean, 0.95, targetLow, targetHigh)
+
+
+ // check that values are within expected tolerance of expectation
+ assert(res == target)
+ }
+
+}
diff --git a/core/src/test/scala/org/apache/spark/rdd/PartitionwiseSampledRDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PartitionwiseSampledRDDSuite.scala
index 132a5fa9a8..cb0de1c6be 100644
--- a/core/src/test/scala/org/apache/spark/rdd/PartitionwiseSampledRDDSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/PartitionwiseSampledRDDSuite.scala
@@ -29,6 +29,8 @@ class MockSampler extends RandomSampler[Long, Long] {
s = seed
}
+ override def sample(): Int = 1
+
override def sample(items: Iterator[Long]): Iterator[Long] = {
Iterator(s)
}
diff --git a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala
index 43e61241b6..cebac2097f 100644
--- a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala
@@ -127,9 +127,8 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll {
override val rpcEnv = env
override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
- case msg: String => {
+ case msg: String =>
context.reply(msg)
- }
}
})
val reply = rpcEndpointRef.askWithRetry[String]("hello")
@@ -141,9 +140,8 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll {
override val rpcEnv = env
override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
- case msg: String => {
+ case msg: String =>
context.reply(msg)
- }
}
})
@@ -164,10 +162,9 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll {
override val rpcEnv = env
override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
- case msg: String => {
+ case msg: String =>
Thread.sleep(100)
context.reply(msg)
- }
}
})
@@ -317,10 +314,9 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll {
override val rpcEnv = env
override def receive: PartialFunction[Any, Unit] = {
- case m => {
+ case m =>
self
callSelfSuccessfully = true
- }
}
})
@@ -682,9 +678,8 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll {
override val rpcEnv = localEnv
override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
- case msg: String => {
+ case msg: String =>
context.reply(msg)
- }
}
})
val rpcEndpointRef = remoteEnv.setupEndpointRef(localEnv.address, "ask-authentication")
diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
index 2293c11dad..fd96fb04f8 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
@@ -1144,7 +1144,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou
// SPARK-9809 -- this stage is submitted without a task for each partition (because some of
// the shuffle map output is still available from stage 0); make sure we've still got internal
// accumulators setup
- assert(scheduler.stageIdToStage(2).internalAccumulators.nonEmpty)
+ assert(scheduler.stageIdToStage(2).latestInfo.internalAccumulators.nonEmpty)
completeShuffleMapStageSuccessfully(2, 0, 2)
completeNextResultStageWithSuccess(3, 1, idx => idx + 1234)
assert(results === Map(0 -> 1234, 1 -> 1235))
diff --git a/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala b/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala
index f7e16af9d3..e3e6df6831 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala
@@ -17,12 +17,14 @@
package org.apache.spark.scheduler
+import java.util.Properties
+
import org.apache.spark.TaskContext
class FakeTask(
stageId: Int,
prefLocs: Seq[TaskLocation] = Nil)
- extends Task[Int](stageId, 0, 0, Seq.empty) {
+ extends Task[Int](stageId, 0, 0, Seq.empty, new Properties) {
override def runTask(context: TaskContext): Int = 0
override def preferredLocations: Seq[TaskLocation] = prefLocs
}
diff --git a/core/src/test/scala/org/apache/spark/scheduler/NotSerializableFakeTask.scala b/core/src/test/scala/org/apache/spark/scheduler/NotSerializableFakeTask.scala
index 1dca4bd89f..76a7087645 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/NotSerializableFakeTask.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/NotSerializableFakeTask.scala
@@ -18,6 +18,7 @@
package org.apache.spark.scheduler
import java.io.{IOException, ObjectInputStream, ObjectOutputStream}
+import java.util.Properties
import org.apache.spark.TaskContext
@@ -25,7 +26,7 @@ import org.apache.spark.TaskContext
* A Task implementation that fails to serialize.
*/
private[spark] class NotSerializableFakeTask(myId: Int, stageId: Int)
- extends Task[Array[Byte]](stageId, 0, 0, Seq.empty) {
+ extends Task[Array[Byte]](stageId, 0, 0, Seq.empty, new Properties) {
override def runTask(context: TaskContext): Array[Byte] = Array.empty[Byte]
override def preferredLocations: Seq[TaskLocation] = Seq[TaskLocation]()
diff --git a/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorIntegrationSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorIntegrationSuite.scala
index 9f41aca8a1..601f1c378c 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorIntegrationSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorIntegrationSuite.scala
@@ -38,7 +38,7 @@ class OutputCommitCoordinatorIntegrationSuite
super.beforeAll()
val conf = new SparkConf()
.set("master", "local[2,4]")
- .set("spark.speculation", "true")
+ .set("spark.hadoop.outputCommitCoordination.enabled", "true")
.set("spark.hadoop.mapred.output.committer.class",
classOf[ThrowExceptionOnFirstAttemptOutputCommitter].getCanonicalName)
sc = new SparkContext("local[2, 4]", "test", conf)
diff --git a/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala
index c461da65bd..8e509de767 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala
@@ -77,7 +77,7 @@ class OutputCommitCoordinatorSuite extends SparkFunSuite with BeforeAndAfter {
val conf = new SparkConf()
.setMaster("local[4]")
.setAppName(classOf[OutputCommitCoordinatorSuite].getSimpleName)
- .set("spark.speculation", "true")
+ .set("spark.hadoop.outputCommitCoordination.enabled", "true")
sc = new SparkContext(conf) {
override private[spark] def createSparkEnv(
conf: SparkConf,
diff --git a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala
index 58d217ffef..b854d742b5 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala
@@ -24,7 +24,7 @@ import scala.collection.JavaConverters._
import org.scalatest.Matchers
-import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkException, SparkFunSuite}
+import org.apache.spark._
import org.apache.spark.executor.TaskMetrics
import org.apache.spark.util.{ResetSystemProperties, RpcUtils}
@@ -377,13 +377,18 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match
}
test("registering listeners via spark.extraListeners") {
+ val listeners = Seq(
+ classOf[ListenerThatAcceptsSparkConf],
+ classOf[FirehoseListenerThatAcceptsSparkConf],
+ classOf[BasicJobCounter])
val conf = new SparkConf().setMaster("local").setAppName("test")
- .set("spark.extraListeners", classOf[ListenerThatAcceptsSparkConf].getName + "," +
- classOf[BasicJobCounter].getName)
+ .set("spark.extraListeners", listeners.map(_.getName).mkString(","))
sc = new SparkContext(conf)
sc.listenerBus.listeners.asScala.count(_.isInstanceOf[BasicJobCounter]) should be (1)
sc.listenerBus.listeners.asScala
.count(_.isInstanceOf[ListenerThatAcceptsSparkConf]) should be (1)
+ sc.listenerBus.listeners.asScala
+ .count(_.isInstanceOf[FirehoseListenerThatAcceptsSparkConf]) should be (1)
}
/**
@@ -476,3 +481,11 @@ private class ListenerThatAcceptsSparkConf(conf: SparkConf) extends SparkListene
var count = 0
override def onJobEnd(job: SparkListenerJobEnd): Unit = count += 1
}
+
+private class FirehoseListenerThatAcceptsSparkConf(conf: SparkConf) extends SparkFirehoseListener {
+ var count = 0
+ override def onEvent(event: SparkListenerEvent): Unit = event match {
+ case job: SparkListenerJobEnd => count += 1
+ case _ =>
+ }
+}
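
For context, the mechanism this test exercises: listener implementations are registered by fully-qualified class name through the spark.extraListeners setting, and (as the test's helper classes suggest) may take either a SparkConf-argument or a no-argument constructor. The class names below are placeholders, not real classes.

    // Sketch only: registering extra listeners via configuration.
    import org.apache.spark.SparkConf

    val conf = new SparkConf()
      .setMaster("local")
      .setAppName("listener-example")
      .set("spark.extraListeners",
        "com.example.MyJobCounter,com.example.MyFirehoseListener")
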
diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala
index 5df541e5a5..5ca0c6419d 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala
@@ -17,12 +17,14 @@
package org.apache.spark.scheduler
+import java.util.Properties
+
import org.mockito.Matchers.any
import org.mockito.Mockito._
import org.scalatest.BeforeAndAfter
import org.apache.spark._
-import org.apache.spark.executor.TaskMetricsSuite
+import org.apache.spark.executor.{Executor, TaskMetricsSuite}
import org.apache.spark.memory.TaskMemoryManager
import org.apache.spark.metrics.source.JvmSource
import org.apache.spark.network.util.JavaUtils
@@ -59,7 +61,8 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark
val closureSerializer = SparkEnv.get.closureSerializer.newInstance()
val func = (c: TaskContext, i: Iterator[String]) => i.next()
val taskBinary = sc.broadcast(JavaUtils.bufferToArray(closureSerializer.serialize((rdd, func))))
- val task = new ResultTask[String, String](0, 0, taskBinary, rdd.partitions(0), Seq.empty, 0)
+ val task = new ResultTask[String, String](
+ 0, 0, taskBinary, rdd.partitions(0), Seq.empty, 0, new Properties)
intercept[RuntimeException] {
task.run(0, 0, null)
}
@@ -80,7 +83,8 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark
val closureSerializer = SparkEnv.get.closureSerializer.newInstance()
val func = (c: TaskContext, i: Iterator[String]) => i.next()
val taskBinary = sc.broadcast(JavaUtils.bufferToArray(closureSerializer.serialize((rdd, func))))
- val task = new ResultTask[String, String](0, 0, taskBinary, rdd.partitions(0), Seq.empty, 0)
+ val task = new ResultTask[String, String](
+ 0, 0, taskBinary, rdd.partitions(0), Seq.empty, 0, new Properties)
intercept[RuntimeException] {
task.run(0, 0, null)
}
@@ -171,9 +175,10 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark
val initialAccums = InternalAccumulator.createAll()
// Create a dummy task. We won't end up running this; we just want to collect
// accumulator updates from it.
- val task = new Task[Int](0, 0, 0, Seq.empty[Accumulator[_]]) {
+ val task = new Task[Int](0, 0, 0, Seq.empty[Accumulator[_]], new Properties) {
context = new TaskContextImpl(0, 0, 0L, 0,
new TaskMemoryManager(SparkEnv.get.memoryManager, 0L),
+ new Properties,
SparkEnv.get.metricsSystem,
initialAccums)
context.taskMetrics.registerAccumulator(acc1)
@@ -190,6 +195,17 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark
TaskMetricsSuite.assertUpdatesEquals(accumUpdates3, accumUpdates4)
}
+ test("localProperties are propagated to executors correctly") {
+ sc = new SparkContext("local", "test")
+ sc.setLocalProperty("testPropKey", "testPropValue")
+ val res = sc.parallelize(Array(1), 1).map(i => i).map(i => {
+ val inTask = TaskContext.get().getLocalProperty("testPropKey")
+ val inDeser = Executor.taskDeserializationProps.get().getProperty("testPropKey")
+ s"$inTask,$inDeser"
+ }).collect()
+ assert(res === Array("testPropValue,testPropValue"))
+ }
+
}
private object TaskContextSuite {
diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
index 167d3fd2e4..ade8e84d84 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
@@ -17,7 +17,7 @@
package org.apache.spark.scheduler
-import java.util.Random
+import java.util.{Properties, Random}
import scala.collection.Map
import scala.collection.mutable
@@ -138,7 +138,7 @@ class FakeTaskScheduler(sc: SparkContext, liveExecutors: (String, String)* /* ex
/**
* A Task implementation that results in a large serialized task.
*/
-class LargeTask(stageId: Int) extends Task[Array[Byte]](stageId, 0, 0, Seq.empty) {
+class LargeTask(stageId: Int) extends Task[Array[Byte]](stageId, 0, 0, Seq.empty, new Properties) {
val randomBuffer = new Array[Byte](TaskSetManager.TASK_SIZE_TO_WARN_KB * 1024)
val random = new Random(0)
random.nextBytes(randomBuffer)
diff --git a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterSchedulerSuite.scala
index dbef6868f2..a32423dc4f 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterSchedulerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterSchedulerSuite.scala
@@ -136,4 +136,40 @@ class MesosClusterSchedulerSuite extends SparkFunSuite with LocalSparkContext wi
capture.capture()
)
}
+
+ test("escapes commandline args for the shell") {
+ val conf = new SparkConf()
+ conf.setMaster("mesos://localhost:5050")
+ conf.setAppName("spark mesos")
+ val scheduler = new MesosClusterScheduler(
+ new BlackHoleMesosClusterPersistenceEngineFactory, conf) {
+ override def start(): Unit = { ready = true }
+ }
+ val escape = scheduler.shellEscape _
+ def wrapped(str: String): String = "\"" + str + "\""
+
+ // Wrapped in quotes
+ assert(escape("'should be left untouched'") === "'should be left untouched'")
+ assert(escape("\"should be left untouched\"") === "\"should be left untouched\"")
+
+ // Harmless
+ assert(escape("") === "")
+ assert(escape("harmless") === "harmless")
+ assert(escape("har-m.l3ss") === "har-m.l3ss")
+
+    // Special chars that must be escaped
+ assert(escape("should escape this \" quote") === wrapped("should escape this \\\" quote"))
+ assert(escape("shouldescape\"quote") === wrapped("shouldescape\\\"quote"))
+ assert(escape("should escape this $ dollar") === wrapped("should escape this \\$ dollar"))
+ assert(escape("should escape this ` backtick") === wrapped("should escape this \\` backtick"))
+ assert(escape("""should escape this \ backslash""")
+ === wrapped("""should escape this \\ backslash"""))
+ assert(escape("""\"?""") === wrapped("""\\\"?"""))
+
+
+    // Special chars that are not escaped, only wrapped
+ List(" ", "'", "<", ">", "&", "|", "?", "*", ";", "!", "#", "(", ")").foreach(char => {
+ assert(escape(s"onlywrap${char}this") === wrapped(s"onlywrap${char}this"))
+ })
+ }
}
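The assertions above pin down the escaping contract rather than its implementation: leave already-quoted or harmless strings alone, backslash-escape ", $, ` and \, and wrap anything containing shell metacharacters in double quotes. A rough stand-alone approximation of that contract (an illustrative sketch, not the MesosClusterScheduler.shellEscape implementation) might look like this:

object ShellEscapeSketch {
  private val needsBackslash = Set('"', '$', '`', '\\')
  private val needsWrapOnly = " '<>&|?*;!#()".toSet

  def shellEscape(value: String): String = {
    val alreadyQuoted =
      (value.startsWith("\"") && value.endsWith("\"")) ||
        (value.startsWith("'") && value.endsWith("'"))
    val hasSpecial = value.exists(c => needsBackslash(c) || needsWrapOnly(c))
    if (alreadyQuoted || !hasSpecial) {
      value // quoted or harmless strings are left untouched
    } else {
      val escaped = value.flatMap {
        case c if needsBackslash(c) => "\\" + c
        case c => c.toString
      }
      "\"" + escaped + "\""
    }
  }

  def main(args: Array[String]): Unit = {
    assert(shellEscape("harmless") == "harmless")
    assert(shellEscape("has a space") == "\"has a space\"")
    assert(shellEscape("a $ dollar") == "\"a \\$ dollar\"")
  }
}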
diff --git a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtilsSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtilsSuite.scala
index 85437b2f80..ceb3a52983 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtilsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtilsSuite.scala
@@ -59,10 +59,10 @@ class MesosSchedulerUtilsSuite extends SparkFunSuite with Matchers with MockitoS
test("parse a non-empty constraint string correctly") {
val expectedMap = Map(
- "tachyon" -> Set("true"),
+ "os" -> Set("centos7"),
"zone" -> Set("us-east-1a", "us-east-1b")
)
- utils.parseConstraintString("tachyon:true;zone:us-east-1a,us-east-1b") should be (expectedMap)
+ utils.parseConstraintString("os:centos7;zone:us-east-1a,us-east-1b") should be (expectedMap)
}
test("parse an empty constraint string correctly") {
@@ -71,35 +71,35 @@ class MesosSchedulerUtilsSuite extends SparkFunSuite with Matchers with MockitoS
test("throw an exception when the input is malformed") {
an[IllegalArgumentException] should be thrownBy
- utils.parseConstraintString("tachyon;zone:us-east")
+ utils.parseConstraintString("os;zone:us-east")
}
test("empty values for attributes' constraints matches all values") {
- val constraintsStr = "tachyon:"
+ val constraintsStr = "os:"
val parsedConstraints = utils.parseConstraintString(constraintsStr)
- parsedConstraints shouldBe Map("tachyon" -> Set())
+ parsedConstraints shouldBe Map("os" -> Set())
val zoneSet = Value.Set.newBuilder().addItem("us-east-1a").addItem("us-east-1b").build()
- val noTachyonOffer = Map("zone" -> zoneSet)
- val tachyonTrueOffer = Map("tachyon" -> Value.Text.newBuilder().setValue("true").build())
- val tachyonFalseOffer = Map("tachyon" -> Value.Text.newBuilder().setValue("false").build())
+ val noOsOffer = Map("zone" -> zoneSet)
+ val centosOffer = Map("os" -> Value.Text.newBuilder().setValue("centos").build())
+ val ubuntuOffer = Map("os" -> Value.Text.newBuilder().setValue("ubuntu").build())
- utils.matchesAttributeRequirements(parsedConstraints, noTachyonOffer) shouldBe false
- utils.matchesAttributeRequirements(parsedConstraints, tachyonTrueOffer) shouldBe true
- utils.matchesAttributeRequirements(parsedConstraints, tachyonFalseOffer) shouldBe true
+ utils.matchesAttributeRequirements(parsedConstraints, noOsOffer) shouldBe false
+ utils.matchesAttributeRequirements(parsedConstraints, centosOffer) shouldBe true
+ utils.matchesAttributeRequirements(parsedConstraints, ubuntuOffer) shouldBe true
}
test("subset match is performed for set attributes") {
val supersetConstraint = Map(
- "tachyon" -> Value.Text.newBuilder().setValue("true").build(),
+ "os" -> Value.Text.newBuilder().setValue("ubuntu").build(),
"zone" -> Value.Set.newBuilder()
.addItem("us-east-1a")
.addItem("us-east-1b")
.addItem("us-east-1c")
.build())
- val zoneConstraintStr = "tachyon:;zone:us-east-1a,us-east-1c"
+ val zoneConstraintStr = "os:;zone:us-east-1a,us-east-1c"
val parsedConstraints = utils.parseConstraintString(zoneConstraintStr)
utils.matchesAttributeRequirements(parsedConstraints, supersetConstraint) shouldBe true
@@ -131,10 +131,10 @@ class MesosSchedulerUtilsSuite extends SparkFunSuite with Matchers with MockitoS
}
test("equality match is performed for text attributes") {
- val offerAttribs = Map("tachyon" -> Value.Text.newBuilder().setValue("true").build())
+ val offerAttribs = Map("os" -> Value.Text.newBuilder().setValue("centos7").build())
- val trueConstraint = utils.parseConstraintString("tachyon:true")
- val falseConstraint = utils.parseConstraintString("tachyon:false")
+ val trueConstraint = utils.parseConstraintString("os:centos7")
+ val falseConstraint = utils.parseConstraintString("os:ubuntu")
utils.matchesAttributeRequirements(trueConstraint, offerAttribs) shouldBe true
utils.matchesAttributeRequirements(falseConstraint, offerAttribs) shouldBe false
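The constraint strings above follow an `attr:value1,value2;attr2:` format, where an empty value set matches any offered value. A self-contained parser with the same shape (illustrative only, not the MesosSchedulerUtils implementation) could be written as:

object ConstraintParserSketch {
  def parse(constraints: String): Map[String, Set[String]] = {
    if (constraints.isEmpty) {
      Map.empty
    } else {
      constraints.split(";").map { kv =>
        kv.split(":", 2) match {
          case Array(key, "")     => key -> Set.empty[String]   // empty value set matches anything
          case Array(key, values) => key -> values.split(",").toSet
          case _ =>
            throw new IllegalArgumentException(s"Malformed constraint: $kv")
        }
      }.toMap
    }
  }

  def main(args: Array[String]): Unit = {
    assert(parse("os:centos7;zone:us-east-1a,us-east-1b") ==
      Map("os" -> Set("centos7"), "zone" -> Set("us-east-1a", "us-east-1b")))
    assert(parse("os:") == Map("os" -> Set.empty[String]))
  }
}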
diff --git a/core/src/test/scala/org/apache/spark/storage/BlockInfoManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockInfoManagerSuite.scala
index 7ee76aa4c6..9d1bd7ec89 100644
--- a/core/src/test/scala/org/apache/spark/storage/BlockInfoManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/BlockInfoManagerSuite.scala
@@ -17,6 +17,8 @@
package org.apache.spark.storage
+import java.util.Properties
+
import scala.concurrent.{Await, ExecutionContext, Future}
import scala.language.implicitConversions
import scala.reflect.ClassTag
@@ -58,7 +60,8 @@ class BlockInfoManagerSuite extends SparkFunSuite with BeforeAndAfterEach {
private def withTaskId[T](taskAttemptId: Long)(block: => T): T = {
try {
- TaskContext.setTaskContext(new TaskContextImpl(0, 0, taskAttemptId, 0, null, null))
+ TaskContext.setTaskContext(
+ new TaskContextImpl(0, 0, taskAttemptId, 0, null, new Properties, null))
block
} finally {
TaskContext.unset()
diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala
index 98e8450fa1..d26df7e760 100644
--- a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala
@@ -27,7 +27,7 @@ import org.scalatest.{BeforeAndAfter, Matchers}
import org.scalatest.concurrent.Eventually._
import org.apache.spark._
-import org.apache.spark.memory.StaticMemoryManager
+import org.apache.spark.memory.UnifiedMemoryManager
import org.apache.spark.network.BlockTransferService
import org.apache.spark.network.netty.NettyBlockTransferService
import org.apache.spark.rpc.RpcEnv
@@ -60,8 +60,10 @@ class BlockManagerReplicationSuite extends SparkFunSuite with Matchers with Befo
private def makeBlockManager(
maxMem: Long,
name: String = SparkContext.DRIVER_IDENTIFIER): BlockManager = {
- val transfer = new NettyBlockTransferService(conf, securityMgr, numCores = 1)
- val memManager = new StaticMemoryManager(conf, Long.MaxValue, maxMem, numCores = 1)
+ conf.set("spark.testing.memory", maxMem.toString)
+ conf.set("spark.memory.offHeap.size", maxMem.toString)
+ val transfer = new NettyBlockTransferService(conf, securityMgr, "localhost", numCores = 1)
+ val memManager = UnifiedMemoryManager(conf, numCores = 1)
val serializerManager = new SerializerManager(serializer, conf)
val store = new BlockManager(name, rpcEnv, master, serializerManager, conf,
memManager, mapOutputTracker, shuffleManager, transfer, securityMgr, 0)
@@ -76,6 +78,9 @@ class BlockManagerReplicationSuite extends SparkFunSuite with Matchers with Befo
conf.set("spark.authenticate", "false")
conf.set("spark.driver.port", rpcEnv.address.port.toString)
+ conf.set("spark.testing", "true")
+ conf.set("spark.memory.fraction", "1")
+ conf.set("spark.memory.storageFraction", "1")
conf.set("spark.storage.unrollFraction", "0.4")
conf.set("spark.storage.unrollMemoryThreshold", "512")
@@ -172,6 +177,10 @@ class BlockManagerReplicationSuite extends SparkFunSuite with Matchers with Befo
testReplication(5, storageLevels)
}
+ test("block replication - off-heap") {
+ testReplication(2, Seq(OFF_HEAP, StorageLevel(true, true, true, false, 2)))
+ }
+
test("block replication - 2x replication without peers") {
intercept[org.scalatest.exceptions.TestFailedException] {
testReplication(1,
@@ -262,7 +271,8 @@ class BlockManagerReplicationSuite extends SparkFunSuite with Matchers with Befo
val failableTransfer = mock(classOf[BlockTransferService]) // this won't actually work
when(failableTransfer.hostName).thenReturn("some-hostname")
when(failableTransfer.port).thenReturn(1000)
- val memManager = new StaticMemoryManager(conf, Long.MaxValue, 10000, numCores = 1)
+ conf.set("spark.testing.memory", "10000")
+ val memManager = UnifiedMemoryManager(conf, numCores = 1)
val serializerManager = new SerializerManager(serializer, conf)
val failableStore = new BlockManager("failable-store", rpcEnv, master, serializerManager, conf,
memManager, mapOutputTracker, shuffleManager, failableTransfer, securityMgr, 0)
@@ -392,10 +402,14 @@ class BlockManagerReplicationSuite extends SparkFunSuite with Matchers with Befo
// If the block is supposed to be in memory, drop the copy of the block in this store and
// test whether the master is updated with zero memory usage for this store
if (storageLevel.useMemory) {
+ val sl = if (storageLevel.useOffHeap) {
+ StorageLevel(false, true, true, false, 1)
+ } else {
+ MEMORY_ONLY_SER
+ }
// Force the block to be dropped by adding a number of dummy blocks
(1 to 10).foreach {
- i =>
- testStore.putSingle(s"dummy-block-$i", new Array[Byte](1000), MEMORY_ONLY_SER)
+ i => testStore.putSingle(s"dummy-block-$i", new Array[Byte](1000), sl)
}
(1 to 10).foreach {
i => testStore.removeBlock(s"dummy-block-$i")
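The replication suite above now builds its block managers on UnifiedMemoryManager and steers sizing purely through configuration. A sketch of the testing knobs involved follows; the settings mirror the ones the rewritten suites set, and the comments describe the behaviour those suites rely on rather than a normative specification:

import org.apache.spark.SparkConf

object TestMemoryConfSketch {
  // Illustrative helper: a SparkConf shaped like the ones the patched suites build.
  def testConf(maxMem: Long): SparkConf = new SparkConf(false)
    .set("spark.testing", "true")                      // mark this as a test JVM
    .set("spark.testing.memory", maxMem.toString)      // cap the memory the unified manager sees
    .set("spark.memory.fraction", "1")                 // hand the whole budget to execution + storage
    .set("spark.memory.storageFraction", "1")          // ...and let storage take all of it
    .set("spark.memory.offHeap.size", maxMem.toString) // required by the new off-heap replication test
}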
diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
index 6fc32cb30a..a1c2933584 100644
--- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
@@ -34,7 +34,7 @@ import org.scalatest.concurrent.Timeouts._
import org.apache.spark._
import org.apache.spark.executor.DataReadMethod
-import org.apache.spark.memory.{MemoryMode, StaticMemoryManager}
+import org.apache.spark.memory.UnifiedMemoryManager
import org.apache.spark.network.{BlockDataManager, BlockTransferService}
import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer}
import org.apache.spark.network.netty.NettyBlockTransferService
@@ -74,10 +74,12 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE
name: String = SparkContext.DRIVER_IDENTIFIER,
master: BlockManagerMaster = this.master,
transferService: Option[BlockTransferService] = Option.empty): BlockManager = {
+ conf.set("spark.testing.memory", maxMem.toString)
+ conf.set("spark.memory.offHeap.size", maxMem.toString)
val serializer = new KryoSerializer(conf)
val transfer = transferService
- .getOrElse(new NettyBlockTransferService(conf, securityMgr, numCores = 1))
- val memManager = new StaticMemoryManager(conf, Long.MaxValue, maxMem, numCores = 1)
+ .getOrElse(new NettyBlockTransferService(conf, securityMgr, "localhost", numCores = 1))
+ val memManager = UnifiedMemoryManager(conf, numCores = 1)
val serializerManager = new SerializerManager(serializer, conf)
val blockManager = new BlockManager(name, rpcEnv, master, serializerManager, conf,
memManager, mapOutputTracker, shuffleManager, transfer, securityMgr, 0)
@@ -92,6 +94,9 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE
System.setProperty("os.arch", "amd64")
conf = new SparkConf(false)
.set("spark.app.id", "test")
+ .set("spark.testing", "true")
+ .set("spark.memory.fraction", "1")
+ .set("spark.memory.storageFraction", "1")
.set("spark.kryoserializer.buffer", "1m")
.set("spark.test.useCompressedOops", "true")
.set("spark.storage.unrollFraction", "0.4")
@@ -485,7 +490,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE
val blockManager = makeBlockManager(128, "exec", bmMaster)
val getLocations = PrivateMethod[Seq[BlockManagerId]]('getLocations)
val locations = blockManager invokePrivate getLocations(BroadcastBlockId(0))
- assert(locations.map(_.host) === Seq(localHost, localHost, otherHost))
+ assert(locations.map(_.host).toSet === Set(localHost, localHost, otherHost))
}
test("SPARK-9591: getRemoteBytes from another location when Exception throw") {
@@ -510,6 +515,19 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE
}
}
+ test("SPARK-14252: getOrElseUpdate should still read from remote storage") {
+ store = makeBlockManager(8000, "executor1")
+ store2 = makeBlockManager(8000, "executor2")
+ val list1 = List(new Array[Byte](4000))
+ store2.putIterator(
+ "list1", list1.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true)
+ assert(store.getOrElseUpdate(
+ "list1",
+ StorageLevel.MEMORY_ONLY,
+ ClassTag.Any,
+ () => throw new AssertionError("attempted to compute locally")).isLeft)
+ }
+
test("in-memory LRU storage") {
testInMemoryLRUStorage(StorageLevel.MEMORY_ONLY)
}
@@ -518,6 +536,14 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE
testInMemoryLRUStorage(StorageLevel.MEMORY_ONLY_SER)
}
+ test("in-memory LRU storage with off-heap") {
+ testInMemoryLRUStorage(StorageLevel(
+ useDisk = false,
+ useMemory = true,
+ useOffHeap = true,
+ deserialized = false, replication = 1))
+ }
+
private def testInMemoryLRUStorage(storageLevel: StorageLevel): Unit = {
store = makeBlockManager(12000)
val a1 = new Array[Byte](4000)
@@ -608,6 +634,14 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE
testDiskAndMemoryStorage(StorageLevel.MEMORY_AND_DISK_SER, getAsBytes = true)
}
+ test("disk and off-heap memory storage") {
+ testDiskAndMemoryStorage(StorageLevel.OFF_HEAP, getAsBytes = false)
+ }
+
+ test("disk and off-heap memory storage with getLocalBytes") {
+ testDiskAndMemoryStorage(StorageLevel.OFF_HEAP, getAsBytes = true)
+ }
+
def testDiskAndMemoryStorage(
storageLevel: StorageLevel,
getAsBytes: Boolean): Unit = {
@@ -817,12 +851,9 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE
test("block store put failure") {
// Use Java serializer so we can create an unserializable error.
- val transfer = new NettyBlockTransferService(conf, securityMgr, numCores = 1)
- val memoryManager = new StaticMemoryManager(
- conf,
- maxOnHeapExecutionMemory = Long.MaxValue,
- maxOnHeapStorageMemory = 1200,
- numCores = 1)
+ conf.set("spark.testing.memory", "1200")
+ val transfer = new NettyBlockTransferService(conf, securityMgr, "localhost", numCores = 1)
+ val memoryManager = UnifiedMemoryManager(conf, numCores = 1)
val serializerManager = new SerializerManager(new JavaSerializer(conf), conf)
store = new BlockManager(SparkContext.DRIVER_IDENTIFIER, rpcEnv, master,
serializerManager, conf, memoryManager, mapOutputTracker,
@@ -928,6 +959,16 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE
assert(!store.diskStore.contains("list3"), "list3 was in disk store")
assert(!store.diskStore.contains("list4"), "list4 was in disk store")
assert(!store.diskStore.contains("list5"), "list5 was in disk store")
+
+ // remove block - list2 should be removed from disk
+ val updatedBlocks6 = getUpdatedBlocks {
+ store.removeBlock(
+ "list2", tellMaster = true)
+ }
+ assert(updatedBlocks6.size === 1)
+ assert(updatedBlocks6.head._1 === TestBlockId("list2"))
+ assert(updatedBlocks6.head._2.storageLevel == StorageLevel.NONE)
+ assert(!store.diskStore.contains("list2"), "list2 was in disk store")
}
test("query block statuses") {
diff --git a/core/src/test/scala/org/apache/spark/storage/MemoryStoreSuite.scala b/core/src/test/scala/org/apache/spark/storage/MemoryStoreSuite.scala
index 43e832dc02..145d432afe 100644
--- a/core/src/test/scala/org/apache/spark/storage/MemoryStoreSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/MemoryStoreSuite.scala
@@ -27,7 +27,7 @@ import scala.reflect.ClassTag
import org.scalatest._
import org.apache.spark._
-import org.apache.spark.memory.StaticMemoryManager
+import org.apache.spark.memory.{MemoryMode, StaticMemoryManager}
import org.apache.spark.serializer.{KryoSerializer, SerializerManager}
import org.apache.spark.storage.memory.{BlockEvictionHandler, MemoryStore, PartiallySerializedBlock, PartiallyUnrolledIterator}
import org.apache.spark.util._
@@ -86,7 +86,7 @@ class MemoryStoreSuite
assert(memoryStore.currentUnrollMemoryForThisTask === 0)
def reserveUnrollMemoryForThisTask(memory: Long): Boolean = {
- memoryStore.reserveUnrollMemoryForThisTask(TestBlockId(""), memory)
+ memoryStore.reserveUnrollMemoryForThisTask(TestBlockId(""), memory, MemoryMode.ON_HEAP)
}
// Reserve
@@ -99,9 +99,9 @@ class MemoryStoreSuite
assert(!reserveUnrollMemoryForThisTask(1000000))
assert(memoryStore.currentUnrollMemoryForThisTask === 800) // not granted
// Release
- memoryStore.releaseUnrollMemoryForThisTask(100)
+ memoryStore.releaseUnrollMemoryForThisTask(MemoryMode.ON_HEAP, 100)
assert(memoryStore.currentUnrollMemoryForThisTask === 700)
- memoryStore.releaseUnrollMemoryForThisTask(100)
+ memoryStore.releaseUnrollMemoryForThisTask(MemoryMode.ON_HEAP, 100)
assert(memoryStore.currentUnrollMemoryForThisTask === 600)
// Reserve again
assert(reserveUnrollMemoryForThisTask(4400))
@@ -109,9 +109,9 @@ class MemoryStoreSuite
assert(!reserveUnrollMemoryForThisTask(20000))
assert(memoryStore.currentUnrollMemoryForThisTask === 5000) // not granted
// Release again
- memoryStore.releaseUnrollMemoryForThisTask(1000)
+ memoryStore.releaseUnrollMemoryForThisTask(MemoryMode.ON_HEAP, 1000)
assert(memoryStore.currentUnrollMemoryForThisTask === 4000)
- memoryStore.releaseUnrollMemoryForThisTask() // release all
+ memoryStore.releaseUnrollMemoryForThisTask(MemoryMode.ON_HEAP) // release all
assert(memoryStore.currentUnrollMemoryForThisTask === 0)
}
@@ -254,7 +254,7 @@ class MemoryStoreSuite
assert(blockInfoManager.lockNewBlockForWriting(
blockId,
new BlockInfo(StorageLevel.MEMORY_ONLY_SER, classTag, tellMaster = false)))
- val res = memoryStore.putIteratorAsBytes(blockId, iter, classTag)
+ val res = memoryStore.putIteratorAsBytes(blockId, iter, classTag, MemoryMode.ON_HEAP)
blockInfoManager.unlock(blockId)
res
}
@@ -312,7 +312,7 @@ class MemoryStoreSuite
assert(blockInfoManager.lockNewBlockForWriting(
"b1",
new BlockInfo(StorageLevel.MEMORY_ONLY_SER, ClassTag.Any, tellMaster = false)))
- val res = memoryStore.putIteratorAsBytes("b1", bigIterator, ClassTag.Any)
+ val res = memoryStore.putIteratorAsBytes("b1", bigIterator, ClassTag.Any, MemoryMode.ON_HEAP)
blockInfoManager.unlock("b1")
assert(res.isLeft)
assert(memoryStore.currentUnrollMemoryForThisTask > 0)
@@ -333,7 +333,7 @@ class MemoryStoreSuite
assert(blockInfoManager.lockNewBlockForWriting(
"b1",
new BlockInfo(StorageLevel.MEMORY_ONLY_SER, ClassTag.Any, tellMaster = false)))
- val res = memoryStore.putIteratorAsBytes("b1", bigIterator, ClassTag.Any)
+ val res = memoryStore.putIteratorAsBytes("b1", bigIterator, ClassTag.Any, MemoryMode.ON_HEAP)
blockInfoManager.unlock("b1")
assert(res.isLeft)
assert(memoryStore.currentUnrollMemoryForThisTask > 0)
@@ -395,7 +395,7 @@ class MemoryStoreSuite
val blockId = BlockId("rdd_3_10")
blockInfoManager.lockNewBlockForWriting(
blockId, new BlockInfo(StorageLevel.MEMORY_ONLY, ClassTag.Any, tellMaster = false))
- memoryStore.putBytes(blockId, 13000, () => {
+ memoryStore.putBytes(blockId, 13000, MemoryMode.ON_HEAP, () => {
fail("A big ByteBuffer that cannot be put into MemoryStore should not be created")
})
}
@@ -404,7 +404,7 @@ class MemoryStoreSuite
val (memoryStore, _) = makeMemoryStore(12000)
val blockId = BlockId("rdd_3_10")
var bytes: ChunkedByteBuffer = null
- memoryStore.putBytes(blockId, 10000, () => {
+ memoryStore.putBytes(blockId, 10000, MemoryMode.ON_HEAP, () => {
bytes = new ChunkedByteBuffer(ByteBuffer.allocate(10000))
bytes
})
diff --git a/core/src/test/scala/org/apache/spark/storage/StorageStatusListenerSuite.scala b/core/src/test/scala/org/apache/spark/storage/StorageStatusListenerSuite.scala
index 14daa003bc..9835f11a2f 100644
--- a/core/src/test/scala/org/apache/spark/storage/StorageStatusListenerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/StorageStatusListenerSuite.scala
@@ -82,48 +82,51 @@ class StorageStatusListenerSuite extends SparkFunSuite {
assert(listener.executorIdToStorageStatus("fat").numBlocks === 0)
}
- test("task end with updated blocks") {
+ test("updated blocks") {
val listener = new StorageStatusListener(conf)
listener.onBlockManagerAdded(SparkListenerBlockManagerAdded(1L, bm1, 1000L))
listener.onBlockManagerAdded(SparkListenerBlockManagerAdded(1L, bm2, 2000L))
- val taskMetrics1 = new TaskMetrics
- val taskMetrics2 = new TaskMetrics
- val block1 = (RDDBlockId(1, 1), BlockStatus(StorageLevel.DISK_ONLY, 0L, 100L))
- val block2 = (RDDBlockId(1, 2), BlockStatus(StorageLevel.DISK_ONLY, 0L, 200L))
- val block3 = (RDDBlockId(4, 0), BlockStatus(StorageLevel.DISK_ONLY, 0L, 300L))
- taskMetrics1.setUpdatedBlockStatuses(Seq(block1, block2))
- taskMetrics2.setUpdatedBlockStatuses(Seq(block3))
-
- // Task end with new blocks
+
+ val blockUpdateInfos1 = Seq(
+ BlockUpdatedInfo(bm1, RDDBlockId(1, 1), StorageLevel.DISK_ONLY, 0L, 100L),
+ BlockUpdatedInfo(bm1, RDDBlockId(1, 2), StorageLevel.DISK_ONLY, 0L, 200L)
+ )
+ val blockUpdateInfos2 =
+ Seq(BlockUpdatedInfo(bm2, RDDBlockId(4, 0), StorageLevel.DISK_ONLY, 0L, 300L))
+
+ // Add some new blocks
assert(listener.executorIdToStorageStatus("big").numBlocks === 0)
assert(listener.executorIdToStorageStatus("fat").numBlocks === 0)
- listener.onTaskEnd(SparkListenerTaskEnd(1, 0, "obliteration", Success, taskInfo1, taskMetrics1))
+ postUpdateBlock(listener, blockUpdateInfos1)
assert(listener.executorIdToStorageStatus("big").numBlocks === 2)
assert(listener.executorIdToStorageStatus("fat").numBlocks === 0)
assert(listener.executorIdToStorageStatus("big").containsBlock(RDDBlockId(1, 1)))
assert(listener.executorIdToStorageStatus("big").containsBlock(RDDBlockId(1, 2)))
assert(listener.executorIdToStorageStatus("fat").numBlocks === 0)
- listener.onTaskEnd(SparkListenerTaskEnd(1, 0, "obliteration", Success, taskInfo2, taskMetrics2))
+ postUpdateBlock(listener, blockUpdateInfos2)
assert(listener.executorIdToStorageStatus("big").numBlocks === 2)
assert(listener.executorIdToStorageStatus("fat").numBlocks === 1)
assert(listener.executorIdToStorageStatus("big").containsBlock(RDDBlockId(1, 1)))
assert(listener.executorIdToStorageStatus("big").containsBlock(RDDBlockId(1, 2)))
assert(listener.executorIdToStorageStatus("fat").containsBlock(RDDBlockId(4, 0)))
- // Task end with dropped blocks
- val droppedBlock1 = (RDDBlockId(1, 1), BlockStatus(StorageLevel.NONE, 0L, 0L))
- val droppedBlock2 = (RDDBlockId(1, 2), BlockStatus(StorageLevel.NONE, 0L, 0L))
- val droppedBlock3 = (RDDBlockId(4, 0), BlockStatus(StorageLevel.NONE, 0L, 0L))
- taskMetrics1.setUpdatedBlockStatuses(Seq(droppedBlock1, droppedBlock3))
- taskMetrics2.setUpdatedBlockStatuses(Seq(droppedBlock2, droppedBlock3))
+ // Dropped the blocks
+ val droppedBlockInfo1 = Seq(
+ BlockUpdatedInfo(bm1, RDDBlockId(1, 1), StorageLevel.NONE, 0L, 0L),
+ BlockUpdatedInfo(bm1, RDDBlockId(4, 0), StorageLevel.NONE, 0L, 0L)
+ )
+ val droppedBlockInfo2 = Seq(
+ BlockUpdatedInfo(bm2, RDDBlockId(1, 2), StorageLevel.NONE, 0L, 0L),
+ BlockUpdatedInfo(bm2, RDDBlockId(4, 0), StorageLevel.NONE, 0L, 0L)
+ )
- listener.onTaskEnd(SparkListenerTaskEnd(1, 0, "obliteration", Success, taskInfo1, taskMetrics1))
+ postUpdateBlock(listener, droppedBlockInfo1)
assert(listener.executorIdToStorageStatus("big").numBlocks === 1)
assert(listener.executorIdToStorageStatus("fat").numBlocks === 1)
assert(!listener.executorIdToStorageStatus("big").containsBlock(RDDBlockId(1, 1)))
assert(listener.executorIdToStorageStatus("big").containsBlock(RDDBlockId(1, 2)))
assert(listener.executorIdToStorageStatus("fat").containsBlock(RDDBlockId(4, 0)))
- listener.onTaskEnd(SparkListenerTaskEnd(1, 0, "obliteration", Success, taskInfo2, taskMetrics2))
+ postUpdateBlock(listener, droppedBlockInfo2)
assert(listener.executorIdToStorageStatus("big").numBlocks === 1)
assert(listener.executorIdToStorageStatus("fat").numBlocks === 0)
assert(!listener.executorIdToStorageStatus("big").containsBlock(RDDBlockId(1, 1)))
@@ -134,15 +137,14 @@ class StorageStatusListenerSuite extends SparkFunSuite {
test("unpersist RDD") {
val listener = new StorageStatusListener(conf)
listener.onBlockManagerAdded(SparkListenerBlockManagerAdded(1L, bm1, 1000L))
- val taskMetrics1 = new TaskMetrics
- val taskMetrics2 = new TaskMetrics
- val block1 = (RDDBlockId(1, 1), BlockStatus(StorageLevel.DISK_ONLY, 0L, 100L))
- val block2 = (RDDBlockId(1, 2), BlockStatus(StorageLevel.DISK_ONLY, 0L, 200L))
- val block3 = (RDDBlockId(4, 0), BlockStatus(StorageLevel.DISK_ONLY, 0L, 300L))
- taskMetrics1.setUpdatedBlockStatuses(Seq(block1, block2))
- taskMetrics2.setUpdatedBlockStatuses(Seq(block3))
- listener.onTaskEnd(SparkListenerTaskEnd(1, 0, "obliteration", Success, taskInfo1, taskMetrics1))
- listener.onTaskEnd(SparkListenerTaskEnd(1, 0, "obliteration", Success, taskInfo1, taskMetrics2))
+ val blockUpdateInfos1 = Seq(
+ BlockUpdatedInfo(bm1, RDDBlockId(1, 1), StorageLevel.DISK_ONLY, 0L, 100L),
+ BlockUpdatedInfo(bm1, RDDBlockId(1, 2), StorageLevel.DISK_ONLY, 0L, 200L)
+ )
+ val blockUpdateInfos2 =
+ Seq(BlockUpdatedInfo(bm1, RDDBlockId(4, 0), StorageLevel.DISK_ONLY, 0L, 300L))
+ postUpdateBlock(listener, blockUpdateInfos1)
+ postUpdateBlock(listener, blockUpdateInfos2)
assert(listener.executorIdToStorageStatus("big").numBlocks === 3)
// Unpersist RDD
@@ -155,4 +157,11 @@ class StorageStatusListenerSuite extends SparkFunSuite {
listener.onUnpersistRDD(SparkListenerUnpersistRDD(1))
assert(listener.executorIdToStorageStatus("big").numBlocks === 0)
}
+
+ private def postUpdateBlock(
+ listener: StorageStatusListener, updateBlockInfos: Seq[BlockUpdatedInfo]): Unit = {
+ updateBlockInfos.foreach { updateBlockInfo =>
+ listener.onBlockUpdated(SparkListenerBlockUpdated(updateBlockInfo))
+ }
+ }
}
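The new postUpdateBlock helper feeds SparkListenerBlockUpdated events straight to the listener instead of synthesizing task-end metrics. A minimal listener consuming the same event type (a sketch; the event wraps the BlockUpdatedInfo constructor used above) would be:

import org.apache.spark.scheduler.{SparkListener, SparkListenerBlockUpdated}

class BlockUpdateCounter extends SparkListener {
  @volatile var updates = 0

  override def onBlockUpdated(blockUpdated: SparkListenerBlockUpdated): Unit = {
    // Each event wraps a BlockUpdatedInfo: block manager id, block id, storage level and sizes.
    updates += 1
  }
}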
diff --git a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala
index 9876bded33..7d4c0863bc 100644
--- a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala
@@ -322,11 +322,11 @@ class JobProgressListenerSuite extends SparkFunSuite with LocalSparkContext with
assert(stage1Data.inputBytes == 207)
assert(stage0Data.outputBytes == 116)
assert(stage1Data.outputBytes == 208)
- assert(stage0Data.taskData.get(1234L).get.taskMetrics.get.shuffleReadMetrics.get
+ assert(stage0Data.taskData.get(1234L).get.metrics.get.shuffleReadMetrics.get
.totalBlocksFetched == 2)
- assert(stage0Data.taskData.get(1235L).get.taskMetrics.get.shuffleReadMetrics.get
+ assert(stage0Data.taskData.get(1235L).get.metrics.get.shuffleReadMetrics.get
.totalBlocksFetched == 102)
- assert(stage1Data.taskData.get(1236L).get.taskMetrics.get.shuffleReadMetrics.get
+ assert(stage1Data.taskData.get(1236L).get.metrics.get.shuffleReadMetrics.get
.totalBlocksFetched == 202)
// task that was included in a heartbeat
@@ -355,9 +355,9 @@ class JobProgressListenerSuite extends SparkFunSuite with LocalSparkContext with
assert(stage1Data.inputBytes == 614)
assert(stage0Data.outputBytes == 416)
assert(stage1Data.outputBytes == 616)
- assert(stage0Data.taskData.get(1234L).get.taskMetrics.get.shuffleReadMetrics.get
+ assert(stage0Data.taskData.get(1234L).get.metrics.get.shuffleReadMetrics.get
.totalBlocksFetched == 302)
- assert(stage1Data.taskData.get(1237L).get.taskMetrics.get.shuffleReadMetrics.get
+ assert(stage1Data.taskData.get(1237L).get.metrics.get.shuffleReadMetrics.get
.totalBlocksFetched == 402)
}
}
diff --git a/core/src/test/scala/org/apache/spark/ui/storage/StorageTabSuite.scala b/core/src/test/scala/org/apache/spark/ui/storage/StorageTabSuite.scala
index 6b7c538ac8..7d77deeb60 100644
--- a/core/src/test/scala/org/apache/spark/ui/storage/StorageTabSuite.scala
+++ b/core/src/test/scala/org/apache/spark/ui/storage/StorageTabSuite.scala
@@ -106,7 +106,7 @@ class StorageTabSuite extends SparkFunSuite with BeforeAndAfter {
assert(storageListener.rddInfoList.size === 0)
}
- test("task end") {
+ test("block update") {
val myRddInfo0 = rddInfo0
val myRddInfo1 = rddInfo1
val myRddInfo2 = rddInfo2
@@ -120,19 +120,13 @@ class StorageTabSuite extends SparkFunSuite with BeforeAndAfter {
assert(!storageListener._rddInfoMap(1).isCached)
assert(!storageListener._rddInfoMap(2).isCached)
- // Task end with no updated blocks. This should not change anything.
- bus.postToAll(SparkListenerTaskEnd(0, 0, "obliteration", Success, taskInfo, new TaskMetrics))
- assert(storageListener._rddInfoMap.size === 3)
- assert(storageListener.rddInfoList.size === 0)
-
- // Task end with a few new persisted blocks, some from the same RDD
- val metrics1 = new TaskMetrics
- metrics1.setUpdatedBlockStatuses(Seq(
- (RDDBlockId(0, 100), BlockStatus(memAndDisk, 400L, 0L)),
- (RDDBlockId(0, 101), BlockStatus(memAndDisk, 0L, 400L)),
- (RDDBlockId(1, 20), BlockStatus(memAndDisk, 0L, 240L))
- ))
- bus.postToAll(SparkListenerTaskEnd(1, 0, "obliteration", Success, taskInfo, metrics1))
+ // Some blocks updated
+ val blockUpdateInfos = Seq(
+ BlockUpdatedInfo(bm1, RDDBlockId(0, 100), memAndDisk, 400L, 0L),
+ BlockUpdatedInfo(bm1, RDDBlockId(0, 101), memAndDisk, 0L, 400L),
+ BlockUpdatedInfo(bm1, RDDBlockId(1, 20), memAndDisk, 0L, 240L)
+ )
+ postUpdateBlocks(bus, blockUpdateInfos)
assert(storageListener._rddInfoMap(0).memSize === 400L)
assert(storageListener._rddInfoMap(0).diskSize === 400L)
assert(storageListener._rddInfoMap(0).numCachedPartitions === 2)
@@ -144,15 +138,14 @@ class StorageTabSuite extends SparkFunSuite with BeforeAndAfter {
assert(!storageListener._rddInfoMap(2).isCached)
assert(storageListener._rddInfoMap(2).numCachedPartitions === 0)
- // Task end with a few dropped blocks
- val metrics2 = new TaskMetrics
- metrics2.setUpdatedBlockStatuses(Seq(
- (RDDBlockId(0, 100), BlockStatus(none, 0L, 0L)),
- (RDDBlockId(1, 20), BlockStatus(none, 0L, 0L)),
- (RDDBlockId(2, 40), BlockStatus(none, 0L, 0L)), // doesn't actually exist
- (RDDBlockId(4, 80), BlockStatus(none, 0L, 0L)) // doesn't actually exist
- ))
- bus.postToAll(SparkListenerTaskEnd(2, 0, "obliteration", Success, taskInfo, metrics2))
+ // Drop some blocks
+ val blockUpdateInfos2 = Seq(
+ BlockUpdatedInfo(bm1, RDDBlockId(0, 100), none, 0L, 0L),
+ BlockUpdatedInfo(bm1, RDDBlockId(1, 20), none, 0L, 0L),
+ BlockUpdatedInfo(bm1, RDDBlockId(2, 40), none, 0L, 0L), // doesn't actually exist
+ BlockUpdatedInfo(bm1, RDDBlockId(4, 80), none, 0L, 0L) // doesn't actually exist
+ )
+ postUpdateBlocks(bus, blockUpdateInfos2)
assert(storageListener._rddInfoMap(0).memSize === 0L)
assert(storageListener._rddInfoMap(0).diskSize === 400L)
assert(storageListener._rddInfoMap(0).numCachedPartitions === 1)
@@ -169,24 +162,27 @@ class StorageTabSuite extends SparkFunSuite with BeforeAndAfter {
val rddInfo1 = new RDDInfo(1, "rdd1", 1, memOnly, Seq(4))
val stageInfo0 = new StageInfo(0, 0, "stage0", 1, Seq(rddInfo0), Seq.empty, "details")
val stageInfo1 = new StageInfo(1, 0, "stage1", 1, Seq(rddInfo1), Seq.empty, "details")
- val taskMetrics0 = new TaskMetrics
- val taskMetrics1 = new TaskMetrics
- val block0 = (RDDBlockId(0, 1), BlockStatus(memOnly, 100L, 0L))
- val block1 = (RDDBlockId(1, 1), BlockStatus(memOnly, 200L, 0L))
- taskMetrics0.setUpdatedBlockStatuses(Seq(block0))
- taskMetrics1.setUpdatedBlockStatuses(Seq(block1))
+ val blockUpdateInfos1 = Seq(BlockUpdatedInfo(bm1, RDDBlockId(0, 1), memOnly, 100L, 0L))
+ val blockUpdateInfos2 = Seq(BlockUpdatedInfo(bm1, RDDBlockId(1, 1), memOnly, 200L, 0L))
bus.postToAll(SparkListenerBlockManagerAdded(1L, bm1, 1000L))
bus.postToAll(SparkListenerStageSubmitted(stageInfo0))
assert(storageListener.rddInfoList.size === 0)
- bus.postToAll(SparkListenerTaskEnd(0, 0, "big", Success, taskInfo, taskMetrics0))
+ postUpdateBlocks(bus, blockUpdateInfos1)
assert(storageListener.rddInfoList.size === 1)
bus.postToAll(SparkListenerStageSubmitted(stageInfo1))
assert(storageListener.rddInfoList.size === 1)
bus.postToAll(SparkListenerStageCompleted(stageInfo0))
assert(storageListener.rddInfoList.size === 1)
- bus.postToAll(SparkListenerTaskEnd(1, 0, "small", Success, taskInfo1, taskMetrics1))
+ postUpdateBlocks(bus, blockUpdateInfos2)
assert(storageListener.rddInfoList.size === 2)
bus.postToAll(SparkListenerStageCompleted(stageInfo1))
assert(storageListener.rddInfoList.size === 2)
}
+
+ private def postUpdateBlocks(
+ bus: SparkListenerBus, blockUpdateInfos: Seq[BlockUpdatedInfo]): Unit = {
+ blockUpdateInfos.foreach { blockUpdateInfo =>
+ bus.postToAll(SparkListenerBlockUpdated(blockUpdateInfo))
+ }
+ }
}
diff --git a/core/src/test/scala/org/apache/spark/util/CausedBySuite.scala b/core/src/test/scala/org/apache/spark/util/CausedBySuite.scala
new file mode 100644
index 0000000000..4a80e3f1f4
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/util/CausedBySuite.scala
@@ -0,0 +1,56 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util
+
+import org.apache.spark.SparkFunSuite
+
+class CausedBySuite extends SparkFunSuite {
+
+ test("For an error without a cause, should return the error") {
+ val error = new Exception
+
+ val causedBy = error match {
+ case CausedBy(e) => e
+ }
+
+ assert(causedBy === error)
+ }
+
+ test("For an error with a cause, should return the cause of the error") {
+ val cause = new Exception
+ val error = new Exception(cause)
+
+ val causedBy = error match {
+ case CausedBy(e) => e
+ }
+
+ assert(causedBy === cause)
+ }
+
+ test("For an error with a cause that itself has a cause, return the root cause") {
+ val causeOfCause = new Exception
+ val cause = new Exception(causeOfCause)
+ val error = new Exception(cause)
+
+ val causedBy = error match {
+ case CausedBy(e) => e
+ }
+
+ assert(causedBy === causeOfCause)
+ }
+}
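The behaviour asserted above (match any throwable and hand back its innermost cause) can be reproduced by a small recursive extractor. The version below is a stand-alone illustration, not the org.apache.spark.util.CausedBy implementation itself:

object RootCause {
  // Unwrap getCause chains until no further cause exists, then return that root throwable.
  def unapply(t: Throwable): Option[Throwable] =
    Option(t.getCause) match {
      case Some(cause) => unapply(cause)
      case None        => Some(t)
    }
}

// Usage, mirroring the pattern in the tests above:
//   val root = error match { case RootCause(e) => e }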
diff --git a/core/src/test/scala/org/apache/spark/util/FileAppenderSuite.scala b/core/src/test/scala/org/apache/spark/util/FileAppenderSuite.scala
index 280e496498..4fa9f9a8f5 100644
--- a/core/src/test/scala/org/apache/spark/util/FileAppenderSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/FileAppenderSuite.scala
@@ -201,24 +201,29 @@ class FileAppenderSuite extends SparkFunSuite with BeforeAndAfter with Logging {
// Make sure only logging errors
val logger = Logger.getRootLogger
+ val oldLogLevel = logger.getLevel
logger.setLevel(Level.ERROR)
- logger.addAppender(mockAppender)
+ try {
+ logger.addAppender(mockAppender)
- val testOutputStream = new PipedOutputStream()
- val testInputStream = new PipedInputStream(testOutputStream)
+ val testOutputStream = new PipedOutputStream()
+ val testInputStream = new PipedInputStream(testOutputStream)
-    // Closing the stream before the appender tries to read will cause an IOException
- testInputStream.close()
- testOutputStream.close()
- val appender = FileAppender(testInputStream, testFile, new SparkConf)
+      // Closing the stream before the appender tries to read will cause an IOException
+ testInputStream.close()
+ testOutputStream.close()
+ val appender = FileAppender(testInputStream, testFile, new SparkConf)
- appender.awaitTermination()
+ appender.awaitTermination()
- // If InputStream was closed without first stopping the appender, an exception will be logged
- verify(mockAppender, atLeast(1)).doAppend(loggingEventCaptor.capture)
- val loggingEvent = loggingEventCaptor.getValue
- assert(loggingEvent.getThrowableInformation !== null)
- assert(loggingEvent.getThrowableInformation.getThrowable.isInstanceOf[IOException])
+ // If InputStream was closed without first stopping the appender, an exception will be logged
+ verify(mockAppender, atLeast(1)).doAppend(loggingEventCaptor.capture)
+ val loggingEvent = loggingEventCaptor.getValue
+ assert(loggingEvent.getThrowableInformation !== null)
+ assert(loggingEvent.getThrowableInformation.getThrowable.isInstanceOf[IOException])
+ } finally {
+ logger.setLevel(oldLogLevel)
+ }
}
test("file appender async close stream gracefully") {
@@ -228,30 +233,35 @@ class FileAppenderSuite extends SparkFunSuite with BeforeAndAfter with Logging {
// Make sure only logging errors
val logger = Logger.getRootLogger
+ val oldLogLevel = logger.getLevel
logger.setLevel(Level.ERROR)
- logger.addAppender(mockAppender)
+ try {
+ logger.addAppender(mockAppender)
- val testOutputStream = new PipedOutputStream()
- val testInputStream = new PipedInputStream(testOutputStream) with LatchedInputStream
+ val testOutputStream = new PipedOutputStream()
+ val testInputStream = new PipedInputStream(testOutputStream) with LatchedInputStream
-    // Closing the stream before the appender tries to read will cause an IOException
- testInputStream.close()
- testOutputStream.close()
- val appender = FileAppender(testInputStream, testFile, new SparkConf)
+      // Closing the stream before the appender tries to read will cause an IOException
+ testInputStream.close()
+ testOutputStream.close()
+ val appender = FileAppender(testInputStream, testFile, new SparkConf)
- // Stop the appender before an IOException is called during read
- testInputStream.latchReadStarted.await()
- appender.stop()
- testInputStream.latchReadProceed.countDown()
+ // Stop the appender before an IOException is called during read
+ testInputStream.latchReadStarted.await()
+ appender.stop()
+ testInputStream.latchReadProceed.countDown()
- appender.awaitTermination()
+ appender.awaitTermination()
- // Make sure no IOException errors have been logged as a result of appender closing gracefully
- verify(mockAppender, atLeast(0)).doAppend(loggingEventCaptor.capture)
- import scala.collection.JavaConverters._
- loggingEventCaptor.getAllValues.asScala.foreach { loggingEvent =>
- assert(loggingEvent.getThrowableInformation === null
- || !loggingEvent.getThrowableInformation.getThrowable.isInstanceOf[IOException])
+ // Make sure no IOException errors have been logged as a result of appender closing gracefully
+ verify(mockAppender, atLeast(0)).doAppend(loggingEventCaptor.capture)
+ import scala.collection.JavaConverters._
+ loggingEventCaptor.getAllValues.asScala.foreach { loggingEvent =>
+ assert(loggingEvent.getThrowableInformation === null
+ || !loggingEvent.getThrowableInformation.getThrowable.isInstanceOf[IOException])
+ }
+ } finally {
+ logger.setLevel(oldLogLevel)
}
}
diff --git a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala
index 6a2d4c9f2c..de6f408fa8 100644
--- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala
@@ -853,7 +853,7 @@ private[spark] object JsonProtocolSuite extends Assertions {
if (hasHadoopInput) {
val inputMetrics = t.registerInputMetrics(DataReadMethod.Hadoop)
inputMetrics.setBytesRead(d + e + f)
- inputMetrics.incRecordsReadInternal(if (hasRecords) (d + e + f) / 100 else -1)
+ inputMetrics.incRecordsRead(if (hasRecords) (d + e + f) / 100 else -1)
} else {
val sr = t.registerTempShuffleReadMetrics()
sr.incRemoteBytesRead(b + d)
diff --git a/core/src/test/scala/org/apache/spark/util/UninterruptibleThreadSuite.scala b/core/src/test/scala/org/apache/spark/util/UninterruptibleThreadSuite.scala
new file mode 100644
index 0000000000..39b31f8dde
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/util/UninterruptibleThreadSuite.scala
@@ -0,0 +1,159 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util
+
+import java.util.concurrent.{CountDownLatch, TimeUnit}
+
+import scala.util.Random
+
+import com.google.common.util.concurrent.Uninterruptibles
+
+import org.apache.spark.SparkFunSuite
+
+class UninterruptibleThreadSuite extends SparkFunSuite {
+
+  /** Sleep for `millis` milliseconds and return true if the sleep was interrupted. */
+ private def sleep(millis: Long): Boolean = {
+ try {
+ Thread.sleep(millis)
+ false
+ } catch {
+ case _: InterruptedException =>
+ true
+ }
+ }
+
+ test("interrupt when runUninterruptibly is running") {
+ val enterRunUninterruptibly = new CountDownLatch(1)
+ @volatile var hasInterruptedException = false
+ @volatile var interruptStatusBeforeExit = false
+ val t = new UninterruptibleThread("test") {
+ override def run(): Unit = {
+ runUninterruptibly {
+ enterRunUninterruptibly.countDown()
+ hasInterruptedException = sleep(1000)
+ }
+ interruptStatusBeforeExit = Thread.interrupted()
+ }
+ }
+ t.start()
+ assert(enterRunUninterruptibly.await(10, TimeUnit.SECONDS), "await timeout")
+ t.interrupt()
+ t.join()
+ assert(hasInterruptedException === false)
+ assert(interruptStatusBeforeExit === true)
+ }
+
+ test("interrupt before runUninterruptibly runs") {
+ val interruptLatch = new CountDownLatch(1)
+ @volatile var hasInterruptedException = false
+ @volatile var interruptStatusBeforeExit = false
+ val t = new UninterruptibleThread("test") {
+ override def run(): Unit = {
+ Uninterruptibles.awaitUninterruptibly(interruptLatch, 10, TimeUnit.SECONDS)
+ try {
+ runUninterruptibly {
+ assert(false, "Should not reach here")
+ }
+ } catch {
+ case _: InterruptedException => hasInterruptedException = true
+ }
+ interruptStatusBeforeExit = Thread.interrupted()
+ }
+ }
+ t.start()
+ t.interrupt()
+ interruptLatch.countDown()
+ t.join()
+ assert(hasInterruptedException === true)
+ assert(interruptStatusBeforeExit === false)
+ }
+
+ test("nested runUninterruptibly") {
+ val enterRunUninterruptibly = new CountDownLatch(1)
+ val interruptLatch = new CountDownLatch(1)
+ @volatile var hasInterruptedException = false
+ @volatile var interruptStatusBeforeExit = false
+ val t = new UninterruptibleThread("test") {
+ override def run(): Unit = {
+ runUninterruptibly {
+ enterRunUninterruptibly.countDown()
+ Uninterruptibles.awaitUninterruptibly(interruptLatch, 10, TimeUnit.SECONDS)
+ hasInterruptedException = sleep(1)
+ runUninterruptibly {
+ if (sleep(1)) {
+ hasInterruptedException = true
+ }
+ }
+ if (sleep(1)) {
+ hasInterruptedException = true
+ }
+ }
+ interruptStatusBeforeExit = Thread.interrupted()
+ }
+ }
+ t.start()
+ assert(enterRunUninterruptibly.await(10, TimeUnit.SECONDS), "await timeout")
+ t.interrupt()
+ interruptLatch.countDown()
+ t.join()
+ assert(hasInterruptedException === false)
+ assert(interruptStatusBeforeExit === true)
+ }
+
+ test("stress test") {
+ @volatile var hasInterruptedException = false
+ val t = new UninterruptibleThread("test") {
+ override def run(): Unit = {
+ for (i <- 0 until 100) {
+ try {
+ runUninterruptibly {
+ if (sleep(Random.nextInt(10))) {
+ hasInterruptedException = true
+ }
+ runUninterruptibly {
+ if (sleep(Random.nextInt(10))) {
+ hasInterruptedException = true
+ }
+ }
+ if (sleep(Random.nextInt(10))) {
+ hasInterruptedException = true
+ }
+ }
+ Uninterruptibles.sleepUninterruptibly(Random.nextInt(10), TimeUnit.MILLISECONDS)
+ // 50% chance to clear the interrupted status
+ if (Random.nextBoolean()) {
+ Thread.interrupted()
+ }
+ } catch {
+ case _: InterruptedException =>
+ // The first runUninterruptibly may throw InterruptedException if the interrupt status
+ // is set before running `f`.
+ }
+ }
+ }
+ }
+ t.start()
+ for (i <- 0 until 400) {
+ Thread.sleep(Random.nextInt(10))
+ t.interrupt()
+ }
+ t.join()
+ assert(hasInterruptedException === false)
+ }
+}
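The suite above pins down the contract that interrupts arriving inside runUninterruptibly are deferred until the block exits and then surface via the interrupt flag. A small usage sketch of that contract follows; it sits in the same package as the suite because UninterruptibleThread is not public API, and the timing constants are arbitrary:

package org.apache.spark.util

object UninterruptibleThreadSketch {
  def main(args: Array[String]): Unit = {
    val t = new UninterruptibleThread("sketch") {
      override def run(): Unit = {
        runUninterruptibly {
          Thread.sleep(500)           // an interrupt arriving here does not throw...
        }
        // ...it shows up as the interrupt flag once the uninterruptible block is left.
        println(s"interrupted after block: ${Thread.interrupted()}")
      }
    }
    t.start()
    Thread.sleep(100)                 // give the thread time to enter runUninterruptibly
    t.interrupt()
    t.join()
  }
}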
diff --git a/core/src/test/scala/org/apache/spark/util/io/ByteArrayChunkOutputStreamSuite.scala b/core/src/test/scala/org/apache/spark/util/io/ChunkedByteBufferOutputStreamSuite.scala
index 361ec95654..226622075a 100644
--- a/core/src/test/scala/org/apache/spark/util/io/ByteArrayChunkOutputStreamSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/io/ChunkedByteBufferOutputStreamSuite.scala
@@ -17,48 +17,53 @@
package org.apache.spark.util.io
+import java.nio.ByteBuffer
+
import scala.util.Random
import org.apache.spark.SparkFunSuite
-class ByteArrayChunkOutputStreamSuite extends SparkFunSuite {
+class ChunkedByteBufferOutputStreamSuite extends SparkFunSuite {
test("empty output") {
- val o = new ByteArrayChunkOutputStream(1024)
- assert(o.toArrays.length === 0)
+ val o = new ChunkedByteBufferOutputStream(1024, ByteBuffer.allocate)
+ assert(o.toChunkedByteBuffer.size === 0)
}
test("write a single byte") {
- val o = new ByteArrayChunkOutputStream(1024)
+ val o = new ChunkedByteBufferOutputStream(1024, ByteBuffer.allocate)
o.write(10)
- assert(o.toArrays.length === 1)
- assert(o.toArrays.head.toSeq === Seq(10.toByte))
+ val chunkedByteBuffer = o.toChunkedByteBuffer
+ assert(chunkedByteBuffer.getChunks().length === 1)
+ assert(chunkedByteBuffer.getChunks().head.array().toSeq === Seq(10.toByte))
}
test("write a single near boundary") {
- val o = new ByteArrayChunkOutputStream(10)
+ val o = new ChunkedByteBufferOutputStream(10, ByteBuffer.allocate)
o.write(new Array[Byte](9))
o.write(99)
- assert(o.toArrays.length === 1)
- assert(o.toArrays.head(9) === 99.toByte)
+ val chunkedByteBuffer = o.toChunkedByteBuffer
+ assert(chunkedByteBuffer.getChunks().length === 1)
+ assert(chunkedByteBuffer.getChunks().head.array()(9) === 99.toByte)
}
test("write a single at boundary") {
- val o = new ByteArrayChunkOutputStream(10)
+ val o = new ChunkedByteBufferOutputStream(10, ByteBuffer.allocate)
o.write(new Array[Byte](10))
o.write(99)
- assert(o.toArrays.length === 2)
- assert(o.toArrays(1).length === 1)
- assert(o.toArrays(1)(0) === 99.toByte)
+ val arrays = o.toChunkedByteBuffer.getChunks().map(_.array())
+ assert(arrays.length === 2)
+ assert(arrays(1).length === 1)
+ assert(arrays(1)(0) === 99.toByte)
}
test("single chunk output") {
val ref = new Array[Byte](8)
Random.nextBytes(ref)
- val o = new ByteArrayChunkOutputStream(10)
+ val o = new ChunkedByteBufferOutputStream(10, ByteBuffer.allocate)
o.write(ref)
- val arrays = o.toArrays
+ val arrays = o.toChunkedByteBuffer.getChunks().map(_.array())
assert(arrays.length === 1)
assert(arrays.head.length === ref.length)
assert(arrays.head.toSeq === ref.toSeq)
@@ -67,9 +72,9 @@ class ByteArrayChunkOutputStreamSuite extends SparkFunSuite {
test("single chunk output at boundary size") {
val ref = new Array[Byte](10)
Random.nextBytes(ref)
- val o = new ByteArrayChunkOutputStream(10)
+ val o = new ChunkedByteBufferOutputStream(10, ByteBuffer.allocate)
o.write(ref)
- val arrays = o.toArrays
+ val arrays = o.toChunkedByteBuffer.getChunks().map(_.array())
assert(arrays.length === 1)
assert(arrays.head.length === ref.length)
assert(arrays.head.toSeq === ref.toSeq)
@@ -78,9 +83,9 @@ class ByteArrayChunkOutputStreamSuite extends SparkFunSuite {
test("multiple chunk output") {
val ref = new Array[Byte](26)
Random.nextBytes(ref)
- val o = new ByteArrayChunkOutputStream(10)
+ val o = new ChunkedByteBufferOutputStream(10, ByteBuffer.allocate)
o.write(ref)
- val arrays = o.toArrays
+ val arrays = o.toChunkedByteBuffer.getChunks().map(_.array())
assert(arrays.length === 3)
assert(arrays(0).length === 10)
assert(arrays(1).length === 10)
@@ -94,9 +99,9 @@ class ByteArrayChunkOutputStreamSuite extends SparkFunSuite {
test("multiple chunk output at boundary size") {
val ref = new Array[Byte](30)
Random.nextBytes(ref)
- val o = new ByteArrayChunkOutputStream(10)
+ val o = new ChunkedByteBufferOutputStream(10, ByteBuffer.allocate)
o.write(ref)
- val arrays = o.toArrays
+ val arrays = o.toChunkedByteBuffer.getChunks().map(_.array())
assert(arrays.length === 3)
assert(arrays(0).length === 10)
assert(arrays(1).length === 10)
diff --git a/core/src/test/scala/org/apache/spark/util/random/RandomSamplerSuite.scala b/core/src/test/scala/org/apache/spark/util/random/RandomSamplerSuite.scala
index 791491daf0..7eb2f56c20 100644
--- a/core/src/test/scala/org/apache/spark/util/random/RandomSamplerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/random/RandomSamplerSuite.scala
@@ -129,6 +129,13 @@ class RandomSamplerSuite extends SparkFunSuite with Matchers {
t(m / 2)
}
+ def replacementSampling(data: Iterator[Int], sampler: PoissonSampler[Int]): Iterator[Int] = {
+ data.flatMap { item =>
+ val count = sampler.sample()
+ if (count == 0) Iterator.empty else Iterator.fill(count)(item)
+ }
+ }
+
test("utilities") {
val s1 = Array(0, 1, 1, 0, 2)
val s2 = Array(1, 0, 3, 2, 1)
@@ -189,6 +196,36 @@ class RandomSamplerSuite extends SparkFunSuite with Matchers {
d should be > D
}
+ test("bernoulli sampling without iterator") {
+ // Tests expect maximum gap sampling fraction to be this value
+ RandomSampler.defaultMaxGapSamplingFraction should be (0.4)
+
+ var d: Double = 0.0
+
+ val data = Iterator.from(0)
+
+ var sampler: RandomSampler[Int, Int] = new BernoulliSampler[Int](0.5)
+ sampler.setSeed(rngSeed.nextLong)
+ d = medianKSD(gaps(data.filter(_ => sampler.sample() > 0)), gaps(sample(Iterator.from(0), 0.5)))
+ d should be < D
+
+ sampler = new BernoulliSampler[Int](0.7)
+ sampler.setSeed(rngSeed.nextLong)
+ d = medianKSD(gaps(data.filter(_ => sampler.sample() > 0)), gaps(sample(Iterator.from(0), 0.7)))
+ d should be < D
+
+ sampler = new BernoulliSampler[Int](0.9)
+ sampler.setSeed(rngSeed.nextLong)
+ d = medianKSD(gaps(data.filter(_ => sampler.sample() > 0)), gaps(sample(Iterator.from(0), 0.9)))
+ d should be < D
+
+ // sampling at different frequencies should show up as statistically different:
+ sampler = new BernoulliSampler[Int](0.5)
+ sampler.setSeed(rngSeed.nextLong)
+ d = medianKSD(gaps(data.filter(_ => sampler.sample() > 0)), gaps(sample(Iterator.from(0), 0.6)))
+ d should be > D
+ }
+
test("bernoulli sampling with gap sampling optimization") {
// Tests expect maximum gap sampling fraction to be this value
RandomSampler.defaultMaxGapSamplingFraction should be (0.4)
@@ -217,6 +254,37 @@ class RandomSamplerSuite extends SparkFunSuite with Matchers {
d should be > D
}
+ test("bernoulli sampling (without iterator) with gap sampling optimization") {
+ // Tests expect maximum gap sampling fraction to be this value
+ RandomSampler.defaultMaxGapSamplingFraction should be (0.4)
+
+ var d: Double = 0.0
+
+ val data = Iterator.from(0)
+
+ var sampler: RandomSampler[Int, Int] = new BernoulliSampler[Int](0.01)
+ sampler.setSeed(rngSeed.nextLong)
+ d = medianKSD(gaps(data.filter(_ => sampler.sample() > 0)),
+ gaps(sample(Iterator.from(0), 0.01)))
+ d should be < D
+
+ sampler = new BernoulliSampler[Int](0.1)
+ sampler.setSeed(rngSeed.nextLong)
+ d = medianKSD(gaps(data.filter(_ => sampler.sample() > 0)), gaps(sample(Iterator.from(0), 0.1)))
+ d should be < D
+
+ sampler = new BernoulliSampler[Int](0.3)
+ sampler.setSeed(rngSeed.nextLong)
+ d = medianKSD(gaps(data.filter(_ => sampler.sample() > 0)), gaps(sample(Iterator.from(0), 0.3)))
+ d should be < D
+
+ // sampling at different frequencies should show up as statistically different:
+ sampler = new BernoulliSampler[Int](0.3)
+ sampler.setSeed(rngSeed.nextLong)
+ d = medianKSD(gaps(data.filter(_ => sampler.sample() > 0)), gaps(sample(Iterator.from(0), 0.4)))
+ d should be > D
+ }
+
test("bernoulli boundary cases") {
val data = (1 to 100).toArray
@@ -233,6 +301,22 @@ class RandomSamplerSuite extends SparkFunSuite with Matchers {
sampler.sample(data.iterator).toArray should be (data)
}
+ test("bernoulli (without iterator) boundary cases") {
+ val data = (1 to 100).toArray
+
+ var sampler = new BernoulliSampler[Int](0.0)
+ data.filter(_ => sampler.sample() > 0) should be (Array.empty[Int])
+
+ sampler = new BernoulliSampler[Int](1.0)
+ data.filter(_ => sampler.sample() > 0) should be (data)
+
+ sampler = new BernoulliSampler[Int](0.0 - (RandomSampler.roundingEpsilon / 2.0))
+ data.filter(_ => sampler.sample() > 0) should be (Array.empty[Int])
+
+ sampler = new BernoulliSampler[Int](1.0 + (RandomSampler.roundingEpsilon / 2.0))
+ data.filter(_ => sampler.sample() > 0) should be (data)
+ }
+
test("bernoulli data types") {
// Tests expect maximum gap sampling fraction to be this value
RandomSampler.defaultMaxGapSamplingFraction should be (0.4)
@@ -341,6 +425,36 @@ class RandomSamplerSuite extends SparkFunSuite with Matchers {
d should be > D
}
+ test("replacement sampling without iterator") {
+ // Tests expect maximum gap sampling fraction to be this value
+ RandomSampler.defaultMaxGapSamplingFraction should be (0.4)
+
+ var d: Double = 0.0
+
+ val data = Iterator.from(0)
+
+ var sampler = new PoissonSampler[Int](0.5)
+ sampler.setSeed(rngSeed.nextLong)
+ d = medianKSD(gaps(replacementSampling(data, sampler)), gaps(sampleWR(Iterator.from(0), 0.5)))
+ d should be < D
+
+ sampler = new PoissonSampler[Int](0.7)
+ sampler.setSeed(rngSeed.nextLong)
+ d = medianKSD(gaps(replacementSampling(data, sampler)), gaps(sampleWR(Iterator.from(0), 0.7)))
+ d should be < D
+
+ sampler = new PoissonSampler[Int](0.9)
+ sampler.setSeed(rngSeed.nextLong)
+ d = medianKSD(gaps(replacementSampling(data, sampler)), gaps(sampleWR(Iterator.from(0), 0.9)))
+ d should be < D
+
+ // sampling at different frequencies should show up as statistically different:
+ sampler = new PoissonSampler[Int](0.5)
+ sampler.setSeed(rngSeed.nextLong)
+ d = medianKSD(gaps(replacementSampling(data, sampler)), gaps(sampleWR(Iterator.from(0), 0.6)))
+ d should be > D
+ }
+
test("replacement sampling with gap sampling") {
// Tests expect maximum gap sampling fraction to be this value
RandomSampler.defaultMaxGapSamplingFraction should be (0.4)
@@ -369,6 +483,36 @@ class RandomSamplerSuite extends SparkFunSuite with Matchers {
d should be > D
}
+ test("replacement sampling (without iterator) with gap sampling") {
+ // Tests expect maximum gap sampling fraction to be this value
+ RandomSampler.defaultMaxGapSamplingFraction should be (0.4)
+
+ var d: Double = 0.0
+
+ val data = Iterator.from(0)
+
+ var sampler = new PoissonSampler[Int](0.01)
+ sampler.setSeed(rngSeed.nextLong)
+ d = medianKSD(gaps(replacementSampling(data, sampler)), gaps(sampleWR(Iterator.from(0), 0.01)))
+ d should be < D
+
+ sampler = new PoissonSampler[Int](0.1)
+ sampler.setSeed(rngSeed.nextLong)
+ d = medianKSD(gaps(replacementSampling(data, sampler)), gaps(sampleWR(Iterator.from(0), 0.1)))
+ d should be < D
+
+ sampler = new PoissonSampler[Int](0.3)
+ sampler.setSeed(rngSeed.nextLong)
+ d = medianKSD(gaps(replacementSampling(data, sampler)), gaps(sampleWR(Iterator.from(0), 0.3)))
+ d should be < D
+
+ // sampling at different frequencies should show up as statistically different:
+ sampler = new PoissonSampler[Int](0.3)
+ sampler.setSeed(rngSeed.nextLong)
+ d = medianKSD(gaps(replacementSampling(data, sampler)), gaps(sampleWR(Iterator.from(0), 0.4)))
+ d should be > D
+ }
+
test("replacement boundary cases") {
val data = (1 to 100).toArray
@@ -383,6 +527,20 @@ class RandomSamplerSuite extends SparkFunSuite with Matchers {
sampler.sample(data.iterator).length should be > (data.length)
}
+ test("replacement (without) boundary cases") {
+ val data = (1 to 100).toArray
+
+ var sampler = new PoissonSampler[Int](0.0)
+ replacementSampling(data.iterator, sampler).toArray should be (Array.empty[Int])
+
+ sampler = new PoissonSampler[Int](0.0 - (RandomSampler.roundingEpsilon / 2.0))
+ replacementSampling(data.iterator, sampler).toArray should be (Array.empty[Int])
+
+ // sampling with replacement has no upper bound on sampling fraction
+ sampler = new PoissonSampler[Int](2.0)
+ replacementSampling(data.iterator, sampler).length should be > (data.length)
+ }
+
test("replacement data types") {
// Tests expect maximum gap sampling fraction to be this value
RandomSampler.defaultMaxGapSamplingFraction should be (0.4)
@@ -477,6 +635,22 @@ class RandomSamplerSuite extends SparkFunSuite with Matchers {
d should be < D
}
+ test("bernoulli partitioning sampling without iterator") {
+ var d: Double = 0.0
+
+ val data = Iterator.from(0)
+
+ var sampler = new BernoulliCellSampler[Int](0.1, 0.2)
+ sampler.setSeed(rngSeed.nextLong)
+ d = medianKSD(gaps(data.filter(_ => sampler.sample() > 0)), gaps(sample(Iterator.from(0), 0.1)))
+ d should be < D
+
+ sampler = new BernoulliCellSampler[Int](0.1, 0.2, true)
+ sampler.setSeed(rngSeed.nextLong)
+ d = medianKSD(gaps(data.filter(_ => sampler.sample() > 0)), gaps(sample(Iterator.from(0), 0.9)))
+ d should be < D
+ }
+
test("bernoulli partitioning boundary cases") {
val data = (1 to 100).toArray
val d = RandomSampler.roundingEpsilon / 2.0
@@ -500,6 +674,29 @@ class RandomSamplerSuite extends SparkFunSuite with Matchers {
sampler.sample(data.iterator).toArray should be (Array.empty[Int])
}
+ test("bernoulli partitioning (without iterator) boundary cases") {
+ val data = (1 to 100).toArray
+ val d = RandomSampler.roundingEpsilon / 2.0
+
+ var sampler = new BernoulliCellSampler[Int](0.0, 0.0)
+ data.filter(_ => sampler.sample() > 0).toArray should be (Array.empty[Int])
+
+ sampler = new BernoulliCellSampler[Int](0.5, 0.5)
+ data.filter(_ => sampler.sample() > 0).toArray should be (Array.empty[Int])
+
+ sampler = new BernoulliCellSampler[Int](1.0, 1.0)
+ data.filter(_ => sampler.sample() > 0).toArray should be (Array.empty[Int])
+
+ sampler = new BernoulliCellSampler[Int](0.0, 1.0)
+ data.filter(_ => sampler.sample() > 0).toArray should be (data)
+
+ sampler = new BernoulliCellSampler[Int](0.0 - d, 1.0 + d)
+ data.filter(_ => sampler.sample() > 0).toArray should be (data)
+
+ sampler = new BernoulliCellSampler[Int](0.5, 0.5 - d)
+ data.filter(_ => sampler.sample() > 0).toArray should be (Array.empty[Int])
+ }
+
test("bernoulli partitioning data") {
val seed = rngSeed.nextLong
val data = (1 to 100).toArray
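The new suites above all lean on the iterator-free `sample()` overload, which returns how many copies of the current element to emit instead of consuming an `Iterator`. A minimal sketch of that usage pattern, assuming the suite's `replacementSampling` helper (defined outside this hunk) simply expands each element `sample()` times:

    import org.apache.spark.util.random.{BernoulliSampler, PoissonSampler}

    // Without replacement: keep an element whenever sample() says to emit at least one copy.
    val bern = new BernoulliSampler[Int](0.3)
    bern.setSeed(42L)
    val kept = (1 to 1000).filter(_ => bern.sample() > 0)

    // With replacement: emit sample() copies of each element (may be zero or several).
    val pois = new PoissonSampler[Int](1.5)
    pois.setSeed(42L)
    val withReplacement = (1 to 1000).iterator.flatMap(x => Iterator.fill(pois.sample())(x)).toArray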
diff --git a/dev/deps/spark-deps-hadoop-2.2 b/dev/deps/spark-deps-hadoop-2.2
index 512675a599..023fba5369 100644
--- a/dev/deps/spark-deps-hadoop-2.2
+++ b/dev/deps/spark-deps-hadoop-2.2
@@ -2,7 +2,9 @@ JavaEWAH-0.3.2.jar
RoaringBitmap-0.5.11.jar
ST4-4.0.4.jar
activation-1.1.jar
-antlr-runtime-3.5.2.jar
+antlr-2.7.7.jar
+antlr-runtime-3.4.jar
+antlr4-runtime-4.5.2-1.jar
aopalliance-1.0.jar
apache-log4j-extras-1.2.17.jar
arpack_combined_all-0.1.jar
@@ -10,7 +12,6 @@ asm-3.1.jar
asm-commons-3.1.jar
asm-tree-3.1.jar
avro-1.7.7.jar
-avro-ipc-1.7.7-tests.jar
avro-ipc-1.7.7.jar
avro-mapred-1.7.7-hadoop2.jar
bonecp-0.8.0.RELEASE.jar
@@ -19,8 +20,8 @@ breeze_2.11-0.11.2.jar
calcite-avatica-1.2.0-incubating.jar
calcite-core-1.2.0-incubating.jar
calcite-linq4j-1.2.0-incubating.jar
-chill-java-0.7.4.jar
-chill_2.11-0.7.4.jar
+chill-java-0.8.0.jar
+chill_2.11-0.8.0.jar
commons-beanutils-1.7.0.jar
commons-beanutils-core-1.8.0.jar
commons-cli-1.2.jar
@@ -59,6 +60,7 @@ grizzly-http-2.1.2.jar
grizzly-http-server-2.1.2.jar
grizzly-http-servlet-2.1.2.jar
grizzly-rcm-2.1.2.jar
+guava-14.0.1.jar
guice-3.0.jar
guice-servlet-3.0.jar
hadoop-annotations-2.2.0.jar
@@ -121,7 +123,7 @@ jsr305-1.3.9.jar
jta-1.1.jar
jtransforms-2.4.0.jar
jul-to-slf4j-1.7.16.jar
-kryo-2.21.jar
+kryo-shaded-3.0.3.jar
leveldbjni-all-1.8.jar
libfb303-0.9.2.jar
libthrift-0.9.2.jar
@@ -134,10 +136,10 @@ metrics-core-3.1.2.jar
metrics-graphite-3.1.2.jar
metrics-json-3.1.2.jar
metrics-jvm-3.1.2.jar
-minlog-1.2.jar
+minlog-1.3.0.jar
netty-3.8.0.Final.jar
netty-all-4.0.29.Final.jar
-objenesis-1.2.jar
+objenesis-2.1.jar
opencsv-2.3.jar
oro-2.0.8.jar
paranamer-2.6.jar
@@ -155,26 +157,24 @@ pmml-schema-1.2.7.jar
protobuf-java-2.5.0.jar
py4j-0.9.2.jar
pyrolite-4.9.jar
-reflectasm-1.07-shaded.jar
-scala-compiler-2.11.7.jar
-scala-library-2.11.7.jar
+scala-compiler-2.11.8.jar
+scala-library-2.11.8.jar
scala-parser-combinators_2.11-1.0.4.jar
-scala-reflect-2.11.7.jar
+scala-reflect-2.11.8.jar
scala-xml_2.11-1.0.2.jar
-scalap-2.11.7.jar
-servlet-api-2.5.jar
+scalap-2.11.8.jar
slf4j-api-1.7.16.jar
slf4j-log4j12-1.7.16.jar
snappy-0.2.jar
-snappy-java-1.1.2.1.jar
+snappy-java-1.1.2.4.jar
spire-macros_2.11-0.7.4.jar
spire_2.11-0.7.4.jar
stax-api-1.0-2.jar
stax-api-1.0.1.jar
stream-2.7.0.jar
+stringtemplate-3.2.1.jar
super-csv-2.2.0.jar
-univocity-parsers-1.5.6.jar
-unused-1.0.0.jar
+univocity-parsers-2.0.2.jar
xbean-asm5-shaded-4.4.jar
xmlenc-0.52.jar
xz-1.0.jar
diff --git a/dev/deps/spark-deps-hadoop-2.3 b/dev/deps/spark-deps-hadoop-2.3
index 31f8694fed..003c540d72 100644
--- a/dev/deps/spark-deps-hadoop-2.3
+++ b/dev/deps/spark-deps-hadoop-2.3
@@ -2,7 +2,9 @@ JavaEWAH-0.3.2.jar
RoaringBitmap-0.5.11.jar
ST4-4.0.4.jar
activation-1.1.1.jar
-antlr-runtime-3.5.2.jar
+antlr-2.7.7.jar
+antlr-runtime-3.4.jar
+antlr4-runtime-4.5.2-1.jar
aopalliance-1.0.jar
apache-log4j-extras-1.2.17.jar
arpack_combined_all-0.1.jar
@@ -10,7 +12,6 @@ asm-3.1.jar
asm-commons-3.1.jar
asm-tree-3.1.jar
avro-1.7.7.jar
-avro-ipc-1.7.7-tests.jar
avro-ipc-1.7.7.jar
avro-mapred-1.7.7-hadoop2.jar
base64-2.3.8.jar
@@ -21,8 +22,8 @@ breeze_2.11-0.11.2.jar
calcite-avatica-1.2.0-incubating.jar
calcite-core-1.2.0-incubating.jar
calcite-linq4j-1.2.0-incubating.jar
-chill-java-0.7.4.jar
-chill_2.11-0.7.4.jar
+chill-java-0.8.0.jar
+chill_2.11-0.8.0.jar
commons-beanutils-1.7.0.jar
commons-beanutils-core-1.8.0.jar
commons-cli-1.2.jar
@@ -54,6 +55,7 @@ eigenbase-properties-1.1.5.jar
geronimo-annotation_1.0_spec-1.1.1.jar
geronimo-jaspic_1.0_spec-1.0.jar
geronimo-jta_1.1_spec-1.1.1.jar
+guava-14.0.1.jar
guice-3.0.jar
guice-servlet-3.0.jar
hadoop-annotations-2.3.0.jar
@@ -112,7 +114,7 @@ jsr305-1.3.9.jar
jta-1.1.jar
jtransforms-2.4.0.jar
jul-to-slf4j-1.7.16.jar
-kryo-2.21.jar
+kryo-shaded-3.0.3.jar
leveldbjni-all-1.8.jar
libfb303-0.9.2.jar
libthrift-0.9.2.jar
@@ -124,11 +126,11 @@ metrics-core-3.1.2.jar
metrics-graphite-3.1.2.jar
metrics-json-3.1.2.jar
metrics-jvm-3.1.2.jar
-minlog-1.2.jar
+minlog-1.3.0.jar
mx4j-3.0.2.jar
netty-3.8.0.Final.jar
netty-all-4.0.29.Final.jar
-objenesis-1.2.jar
+objenesis-2.1.jar
opencsv-2.3.jar
oro-2.0.8.jar
paranamer-2.6.jar
@@ -146,26 +148,24 @@ pmml-schema-1.2.7.jar
protobuf-java-2.5.0.jar
py4j-0.9.2.jar
pyrolite-4.9.jar
-reflectasm-1.07-shaded.jar
-scala-compiler-2.11.7.jar
-scala-library-2.11.7.jar
+scala-compiler-2.11.8.jar
+scala-library-2.11.8.jar
scala-parser-combinators_2.11-1.0.4.jar
-scala-reflect-2.11.7.jar
+scala-reflect-2.11.8.jar
scala-xml_2.11-1.0.2.jar
-scalap-2.11.7.jar
-servlet-api-2.5.jar
+scalap-2.11.8.jar
slf4j-api-1.7.16.jar
slf4j-log4j12-1.7.16.jar
snappy-0.2.jar
-snappy-java-1.1.2.1.jar
+snappy-java-1.1.2.4.jar
spire-macros_2.11-0.7.4.jar
spire_2.11-0.7.4.jar
stax-api-1.0-2.jar
stax-api-1.0.1.jar
stream-2.7.0.jar
+stringtemplate-3.2.1.jar
super-csv-2.2.0.jar
-univocity-parsers-1.5.6.jar
-unused-1.0.0.jar
+univocity-parsers-2.0.2.jar
xbean-asm5-shaded-4.4.jar
xmlenc-0.52.jar
xz-1.0.jar
diff --git a/dev/deps/spark-deps-hadoop-2.4 b/dev/deps/spark-deps-hadoop-2.4
index 0fa8bccab0..80fbaea222 100644
--- a/dev/deps/spark-deps-hadoop-2.4
+++ b/dev/deps/spark-deps-hadoop-2.4
@@ -2,7 +2,9 @@ JavaEWAH-0.3.2.jar
RoaringBitmap-0.5.11.jar
ST4-4.0.4.jar
activation-1.1.1.jar
-antlr-runtime-3.5.2.jar
+antlr-2.7.7.jar
+antlr-runtime-3.4.jar
+antlr4-runtime-4.5.2-1.jar
aopalliance-1.0.jar
apache-log4j-extras-1.2.17.jar
arpack_combined_all-0.1.jar
@@ -10,7 +12,6 @@ asm-3.1.jar
asm-commons-3.1.jar
asm-tree-3.1.jar
avro-1.7.7.jar
-avro-ipc-1.7.7-tests.jar
avro-ipc-1.7.7.jar
avro-mapred-1.7.7-hadoop2.jar
base64-2.3.8.jar
@@ -21,8 +22,8 @@ breeze_2.11-0.11.2.jar
calcite-avatica-1.2.0-incubating.jar
calcite-core-1.2.0-incubating.jar
calcite-linq4j-1.2.0-incubating.jar
-chill-java-0.7.4.jar
-chill_2.11-0.7.4.jar
+chill-java-0.8.0.jar
+chill_2.11-0.8.0.jar
commons-beanutils-1.7.0.jar
commons-beanutils-core-1.8.0.jar
commons-cli-1.2.jar
@@ -54,6 +55,7 @@ eigenbase-properties-1.1.5.jar
geronimo-annotation_1.0_spec-1.1.1.jar
geronimo-jaspic_1.0_spec-1.0.jar
geronimo-jta_1.1_spec-1.1.1.jar
+guava-14.0.1.jar
guice-3.0.jar
guice-servlet-3.0.jar
hadoop-annotations-2.4.0.jar
@@ -113,7 +115,7 @@ jsr305-1.3.9.jar
jta-1.1.jar
jtransforms-2.4.0.jar
jul-to-slf4j-1.7.16.jar
-kryo-2.21.jar
+kryo-shaded-3.0.3.jar
leveldbjni-all-1.8.jar
libfb303-0.9.2.jar
libthrift-0.9.2.jar
@@ -125,11 +127,11 @@ metrics-core-3.1.2.jar
metrics-graphite-3.1.2.jar
metrics-json-3.1.2.jar
metrics-jvm-3.1.2.jar
-minlog-1.2.jar
+minlog-1.3.0.jar
mx4j-3.0.2.jar
netty-3.8.0.Final.jar
netty-all-4.0.29.Final.jar
-objenesis-1.2.jar
+objenesis-2.1.jar
opencsv-2.3.jar
oro-2.0.8.jar
paranamer-2.6.jar
@@ -147,26 +149,24 @@ pmml-schema-1.2.7.jar
protobuf-java-2.5.0.jar
py4j-0.9.2.jar
pyrolite-4.9.jar
-reflectasm-1.07-shaded.jar
-scala-compiler-2.11.7.jar
-scala-library-2.11.7.jar
+scala-compiler-2.11.8.jar
+scala-library-2.11.8.jar
scala-parser-combinators_2.11-1.0.4.jar
-scala-reflect-2.11.7.jar
+scala-reflect-2.11.8.jar
scala-xml_2.11-1.0.2.jar
-scalap-2.11.7.jar
-servlet-api-2.5.jar
+scalap-2.11.8.jar
slf4j-api-1.7.16.jar
slf4j-log4j12-1.7.16.jar
snappy-0.2.jar
-snappy-java-1.1.2.1.jar
+snappy-java-1.1.2.4.jar
spire-macros_2.11-0.7.4.jar
spire_2.11-0.7.4.jar
stax-api-1.0-2.jar
stax-api-1.0.1.jar
stream-2.7.0.jar
+stringtemplate-3.2.1.jar
super-csv-2.2.0.jar
-univocity-parsers-1.5.6.jar
-unused-1.0.0.jar
+univocity-parsers-2.0.2.jar
xbean-asm5-shaded-4.4.jar
xmlenc-0.52.jar
xz-1.0.jar
diff --git a/dev/deps/spark-deps-hadoop-2.6 b/dev/deps/spark-deps-hadoop-2.6
index 8d2f6e6e32..b2c2a4caec 100644
--- a/dev/deps/spark-deps-hadoop-2.6
+++ b/dev/deps/spark-deps-hadoop-2.6
@@ -2,7 +2,9 @@ JavaEWAH-0.3.2.jar
RoaringBitmap-0.5.11.jar
ST4-4.0.4.jar
activation-1.1.1.jar
-antlr-runtime-3.5.2.jar
+antlr-2.7.7.jar
+antlr-runtime-3.4.jar
+antlr4-runtime-4.5.2-1.jar
aopalliance-1.0.jar
apache-log4j-extras-1.2.17.jar
apacheds-i18n-2.0.0-M15.jar
@@ -14,7 +16,6 @@ asm-3.1.jar
asm-commons-3.1.jar
asm-tree-3.1.jar
avro-1.7.7.jar
-avro-ipc-1.7.7-tests.jar
avro-ipc-1.7.7.jar
avro-mapred-1.7.7-hadoop2.jar
base64-2.3.8.jar
@@ -25,8 +26,8 @@ breeze_2.11-0.11.2.jar
calcite-avatica-1.2.0-incubating.jar
calcite-core-1.2.0-incubating.jar
calcite-linq4j-1.2.0-incubating.jar
-chill-java-0.7.4.jar
-chill_2.11-0.7.4.jar
+chill-java-0.8.0.jar
+chill_2.11-0.8.0.jar
commons-beanutils-1.7.0.jar
commons-beanutils-core-1.8.0.jar
commons-cli-1.2.jar
@@ -59,6 +60,7 @@ geronimo-annotation_1.0_spec-1.1.1.jar
geronimo-jaspic_1.0_spec-1.0.jar
geronimo-jta_1.1_spec-1.1.1.jar
gson-2.2.4.jar
+guava-14.0.1.jar
guice-3.0.jar
guice-servlet-3.0.jar
hadoop-annotations-2.6.0.jar
@@ -119,7 +121,7 @@ jsr305-1.3.9.jar
jta-1.1.jar
jtransforms-2.4.0.jar
jul-to-slf4j-1.7.16.jar
-kryo-2.21.jar
+kryo-shaded-3.0.3.jar
leveldbjni-all-1.8.jar
libfb303-0.9.2.jar
libthrift-0.9.2.jar
@@ -131,11 +133,11 @@ metrics-core-3.1.2.jar
metrics-graphite-3.1.2.jar
metrics-json-3.1.2.jar
metrics-jvm-3.1.2.jar
-minlog-1.2.jar
+minlog-1.3.0.jar
mx4j-3.0.2.jar
netty-3.8.0.Final.jar
netty-all-4.0.29.Final.jar
-objenesis-1.2.jar
+objenesis-2.1.jar
opencsv-2.3.jar
oro-2.0.8.jar
paranamer-2.6.jar
@@ -153,26 +155,24 @@ pmml-schema-1.2.7.jar
protobuf-java-2.5.0.jar
py4j-0.9.2.jar
pyrolite-4.9.jar
-reflectasm-1.07-shaded.jar
-scala-compiler-2.11.7.jar
-scala-library-2.11.7.jar
+scala-compiler-2.11.8.jar
+scala-library-2.11.8.jar
scala-parser-combinators_2.11-1.0.4.jar
-scala-reflect-2.11.7.jar
+scala-reflect-2.11.8.jar
scala-xml_2.11-1.0.2.jar
-scalap-2.11.7.jar
-servlet-api-2.5.jar
+scalap-2.11.8.jar
slf4j-api-1.7.16.jar
slf4j-log4j12-1.7.16.jar
snappy-0.2.jar
-snappy-java-1.1.2.1.jar
+snappy-java-1.1.2.4.jar
spire-macros_2.11-0.7.4.jar
spire_2.11-0.7.4.jar
stax-api-1.0-2.jar
stax-api-1.0.1.jar
stream-2.7.0.jar
+stringtemplate-3.2.1.jar
super-csv-2.2.0.jar
-univocity-parsers-1.5.6.jar
-unused-1.0.0.jar
+univocity-parsers-2.0.2.jar
xbean-asm5-shaded-4.4.jar
xercesImpl-2.9.1.jar
xmlenc-0.52.jar
diff --git a/dev/deps/spark-deps-hadoop-2.7 b/dev/deps/spark-deps-hadoop-2.7
index a114c4ae8d..71e51883d5 100644
--- a/dev/deps/spark-deps-hadoop-2.7
+++ b/dev/deps/spark-deps-hadoop-2.7
@@ -2,7 +2,9 @@ JavaEWAH-0.3.2.jar
RoaringBitmap-0.5.11.jar
ST4-4.0.4.jar
activation-1.1.1.jar
-antlr-runtime-3.5.2.jar
+antlr-2.7.7.jar
+antlr-runtime-3.4.jar
+antlr4-runtime-4.5.2-1.jar
aopalliance-1.0.jar
apache-log4j-extras-1.2.17.jar
apacheds-i18n-2.0.0-M15.jar
@@ -14,7 +16,6 @@ asm-3.1.jar
asm-commons-3.1.jar
asm-tree-3.1.jar
avro-1.7.7.jar
-avro-ipc-1.7.7-tests.jar
avro-ipc-1.7.7.jar
avro-mapred-1.7.7-hadoop2.jar
base64-2.3.8.jar
@@ -25,8 +26,8 @@ breeze_2.11-0.11.2.jar
calcite-avatica-1.2.0-incubating.jar
calcite-core-1.2.0-incubating.jar
calcite-linq4j-1.2.0-incubating.jar
-chill-java-0.7.4.jar
-chill_2.11-0.7.4.jar
+chill-java-0.8.0.jar
+chill_2.11-0.8.0.jar
commons-beanutils-1.7.0.jar
commons-beanutils-core-1.8.0.jar
commons-cli-1.2.jar
@@ -59,6 +60,7 @@ geronimo-annotation_1.0_spec-1.1.1.jar
geronimo-jaspic_1.0_spec-1.0.jar
geronimo-jta_1.1_spec-1.1.1.jar
gson-2.2.4.jar
+guava-14.0.1.jar
guice-3.0.jar
guice-servlet-3.0.jar
hadoop-annotations-2.7.0.jar
@@ -120,7 +122,7 @@ jsr305-1.3.9.jar
jta-1.1.jar
jtransforms-2.4.0.jar
jul-to-slf4j-1.7.16.jar
-kryo-2.21.jar
+kryo-shaded-3.0.3.jar
leveldbjni-all-1.8.jar
libfb303-0.9.2.jar
libthrift-0.9.2.jar
@@ -132,11 +134,11 @@ metrics-core-3.1.2.jar
metrics-graphite-3.1.2.jar
metrics-json-3.1.2.jar
metrics-jvm-3.1.2.jar
-minlog-1.2.jar
+minlog-1.3.0.jar
mx4j-3.0.2.jar
netty-3.8.0.Final.jar
netty-all-4.0.29.Final.jar
-objenesis-1.2.jar
+objenesis-2.1.jar
opencsv-2.3.jar
oro-2.0.8.jar
paranamer-2.6.jar
@@ -154,26 +156,24 @@ pmml-schema-1.2.7.jar
protobuf-java-2.5.0.jar
py4j-0.9.2.jar
pyrolite-4.9.jar
-reflectasm-1.07-shaded.jar
-scala-compiler-2.11.7.jar
-scala-library-2.11.7.jar
+scala-compiler-2.11.8.jar
+scala-library-2.11.8.jar
scala-parser-combinators_2.11-1.0.4.jar
-scala-reflect-2.11.7.jar
+scala-reflect-2.11.8.jar
scala-xml_2.11-1.0.2.jar
-scalap-2.11.7.jar
-servlet-api-2.5.jar
+scalap-2.11.8.jar
slf4j-api-1.7.16.jar
slf4j-log4j12-1.7.16.jar
snappy-0.2.jar
-snappy-java-1.1.2.1.jar
+snappy-java-1.1.2.4.jar
spire-macros_2.11-0.7.4.jar
spire_2.11-0.7.4.jar
stax-api-1.0-2.jar
stax-api-1.0.1.jar
stream-2.7.0.jar
+stringtemplate-3.2.1.jar
super-csv-2.2.0.jar
-univocity-parsers-1.5.6.jar
-unused-1.0.0.jar
+univocity-parsers-2.0.2.jar
xbean-asm5-shaded-4.4.jar
xercesImpl-2.9.1.jar
xmlenc-0.52.jar
diff --git a/dev/make-distribution.sh b/dev/make-distribution.sh
index dbdd42ff9e..4f7544f6ea 100755
--- a/dev/make-distribution.sh
+++ b/dev/make-distribution.sh
@@ -160,28 +160,35 @@ echo -e "\$ ${BUILD_COMMAND[@]}\n"
# Make directories
rm -rf "$DISTDIR"
-mkdir -p "$DISTDIR/lib"
+mkdir -p "$DISTDIR/jars"
echo "Spark $VERSION$GITREVSTRING built for Hadoop $SPARK_HADOOP_VERSION" > "$DISTDIR/RELEASE"
echo "Build flags: $@" >> "$DISTDIR/RELEASE"
# Copy jars
-cp "$SPARK_HOME"/assembly/target/scala*/*assembly*hadoop*.jar "$DISTDIR/lib/"
-# This will fail if the -Pyarn profile is not provided
-# In this case, silence the error and ignore the return code of this command
-cp "$SPARK_HOME"/common/network-yarn/target/scala*/spark-*-yarn-shuffle.jar "$DISTDIR/lib/" &> /dev/null || :
+cp "$SPARK_HOME"/assembly/target/scala*/jars/* "$DISTDIR/jars/"
+
+# Only create the yarn directory if the yarn artifacts were built.
+if [ -f "$SPARK_HOME"/common/network-yarn/target/scala*/spark-*-yarn-shuffle.jar ]; then
+ mkdir "$DISTDIR"/yarn
+ cp "$SPARK_HOME"/common/network-yarn/target/scala*/spark-*-yarn-shuffle.jar "$DISTDIR/yarn"
+fi
# Copy examples and dependencies
mkdir -p "$DISTDIR/examples/jars"
cp "$SPARK_HOME"/examples/target/scala*/jars/* "$DISTDIR/examples/jars"
+# Deduplicate jars that have already been packaged as part of the main Spark dependencies.
+for f in "$DISTDIR/examples/jars/"*; do
+ name=$(basename "$f")
+ if [ -f "$DISTDIR/jars/$name" ]; then
+ rm "$DISTDIR/examples/jars/$name"
+ fi
+done
+
# Copy example sources (needed for python and SQL)
mkdir -p "$DISTDIR/examples/src/main"
cp -r "$SPARK_HOME"/examples/src/main "$DISTDIR/examples/src/"
-if [ "$SPARK_HIVE" == "1" ]; then
- cp "$SPARK_HOME"/lib_managed/jars/datanucleus*.jar "$DISTDIR/lib/"
-fi
-
# Copy license and ASF files
cp "$SPARK_HOME/LICENSE" "$DISTDIR"
cp -r "$SPARK_HOME/licenses" "$DISTDIR"
diff --git a/dev/mima b/dev/mima
index ea746e6f01..c355349045 100755
--- a/dev/mima
+++ b/dev/mima
@@ -25,8 +25,8 @@ FWDIR="$(cd "`dirname "$0"`"/..; pwd)"
cd "$FWDIR"
SPARK_PROFILES="-Pyarn -Pspark-ganglia-lgpl -Pkinesis-asl -Phive-thriftserver -Phive"
-TOOLS_CLASSPATH="$(build/sbt "export tools/fullClasspath" | tail -n1)"
-OLD_DEPS_CLASSPATH="$(build/sbt $SPARK_PROFILES "export oldDeps/fullClasspath" | tail -n1)"
+TOOLS_CLASSPATH="$(build/sbt -DcopyDependencies=false "export tools/fullClasspath" | tail -n1)"
+OLD_DEPS_CLASSPATH="$(build/sbt -DcopyDependencies=false $SPARK_PROFILES "export oldDeps/fullClasspath" | tail -n1)"
rm -f .generated-mima*
@@ -36,7 +36,7 @@ java \
-cp "$TOOLS_CLASSPATH:$OLD_DEPS_CLASSPATH" \
org.apache.spark.tools.GenerateMIMAIgnore
-echo -e "q\n" | build/sbt mimaReportBinaryIssues | grep -v -e "info.*Resolving"
+echo -e "q\n" | build/sbt -DcopyDependencies=false "$@" mimaReportBinaryIssues | grep -v -e "info.*Resolving"
ret_val=$?
if [ $ret_val != 0 ]; then
diff --git a/dev/run-tests.py b/dev/run-tests.py
index c2944747ee..cbe347274e 100755
--- a/dev/run-tests.py
+++ b/dev/run-tests.py
@@ -350,7 +350,7 @@ def build_spark_sbt(hadoop_version):
def build_spark_assembly_sbt(hadoop_version):
# Enable all of the profiles for the build:
build_profiles = get_hadoop_profiles(hadoop_version) + modules.root.build_profile_flags
- sbt_goals = ["assembly/assembly"]
+ sbt_goals = ["assembly/package"]
profiles_and_goals = build_profiles + sbt_goals
print("[info] Building Spark assembly (w/Hive 1.2.1) using SBT with these arguments: ",
" ".join(profiles_and_goals))
@@ -371,9 +371,10 @@ def build_apache_spark(build_tool, hadoop_version):
build_spark_sbt(hadoop_version)
-def detect_binary_inop_with_mima():
+def detect_binary_inop_with_mima(hadoop_version):
+ build_profiles = get_hadoop_profiles(hadoop_version) + modules.root.build_profile_flags
set_title_and_block("Detecting binary incompatibilities with MiMa", "BLOCK_MIMA")
- run_cmd([os.path.join(SPARK_HOME, "dev", "mima")])
+ run_cmd([os.path.join(SPARK_HOME, "dev", "mima")] + build_profiles)
def run_scala_tests_maven(test_profiles):
@@ -571,8 +572,8 @@ def main():
# backwards compatibility checks
if build_tool == "sbt":
# Note: compatibility tests only supported in sbt for now
- detect_binary_inop_with_mima()
- # Since we did not build assembly/assembly before running dev/mima, we need to
+ detect_binary_inop_with_mima(hadoop_version)
+ # Since we did not build assembly/package before running dev/mima, we need to
# do it here because the tests still rely on it; see SPARK-13294 for details.
build_spark_assembly_sbt(hadoop_version)
diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py
index bb04ec6ee6..c844bcff7e 100644
--- a/dev/sparktestsupport/modules.py
+++ b/dev/sparktestsupport/modules.py
@@ -256,9 +256,21 @@ streaming_flume_assembly = Module(
)
+mllib_local = Module(
+ name="mllib-local",
+ dependencies=[],
+ source_file_regexes=[
+ "mllib-local",
+ ],
+ sbt_test_goals=[
+ "mllib-local/test",
+ ]
+)
+
+
mllib = Module(
name="mllib",
- dependencies=[streaming, sql],
+ dependencies=[mllib_local, streaming, sql],
source_file_regexes=[
"data/mllib/",
"mllib/",
diff --git a/docs/building-spark.md b/docs/building-spark.md
index 1e202acb9e..40661604af 100644
--- a/docs/building-spark.md
+++ b/docs/building-spark.md
@@ -180,23 +180,16 @@ For help in setting up IntelliJ IDEA or Eclipse for Spark development, and troub
Running only Java 8 tests and nothing else.
- mvn install -DskipTests -Pjava8-tests
+ mvn install -DskipTests
+ mvn -pl :java8-tests_2.11 test
or
- sbt -Pjava8-tests java8-tests/test
+ sbt java8-tests/test
-Java 8 tests are run when `-Pjava8-tests` profile is enabled, they will run in spite of `-DskipTests`.
-For these tests to run your system must have a JDK 8 installation.
+Java 8 tests are automatically enabled when a Java 8 JDK is detected.
If you have JDK 8 installed but it is not the system default, you can set JAVA_HOME to point to JDK 8 before running the tests.
-# Building for PySpark on YARN
-
-PySpark on YARN is only supported if the jar is built with Maven. Further, there is a known problem
-with building this assembly jar on Red Hat based operating systems (see [SPARK-1753](https://issues.apache.org/jira/browse/SPARK-1753)). If you wish to
-run PySpark on a YARN cluster with Red Hat installed, we recommend that you build the jar elsewhere,
-then ship it over to the cluster. We are investigating the exact cause for this.
-
# Packaging without Hadoop Dependencies for YARN
The assembly jar produced by `mvn package` will, by default, include all of Spark's dependencies, including Hadoop and some of its ecosystem projects. On YARN deployments, this causes multiple versions of these to appear on executor classpaths: the version packaged in the Spark assembly and the version on each node, included with `yarn.application.classpath`. The `hadoop-provided` profile builds the assembly without including Hadoop-ecosystem projects, like ZooKeeper and Hadoop itself.
@@ -210,7 +203,7 @@ compilation. More advanced developers may wish to use SBT.
The SBT build is derived from the Maven POM files, and so the same Maven profiles and variables
can be set to control the SBT build. For example:
- build/sbt -Pyarn -Phadoop-2.3 assembly
+ build/sbt -Pyarn -Phadoop-2.3 package
To avoid the overhead of launching sbt each time you need to re-compile, you can launch sbt
in interactive mode by running `build/sbt`, and then run all build commands at the command
@@ -219,9 +212,9 @@ prompt. For more recommendations on reducing build time, refer to the
# Testing with SBT
-Some of the tests require Spark to be packaged first, so always run `build/sbt assembly` the first time. The following is an example of a correct (build, test) sequence:
+Some of the tests require Spark to be packaged first, so always run `build/sbt package` the first time. The following is an example of a correct (build, test) sequence:
- build/sbt -Pyarn -Phadoop-2.3 -Phive -Phive-thriftserver assembly
+ build/sbt -Pyarn -Phadoop-2.3 -Phive -Phive-thriftserver package
build/sbt -Pyarn -Phadoop-2.3 -Phive -Phive-thriftserver test
To run only a specific test suite as follows:
diff --git a/docs/configuration.md b/docs/configuration.md
index 937852ffde..16d5be62f9 100644
--- a/docs/configuration.md
+++ b/docs/configuration.md
@@ -225,11 +225,14 @@ Apart from these, the following properties are also available, and may be useful
<td>(none)</td>
<td>
A string of extra JVM options to pass to the driver. For instance, GC settings or other logging.
+ Note that it is illegal to set maximum heap size (-Xmx) settings with this option. Maximum heap
+ size settings can be set with <code>spark.driver.memory</code> in the cluster mode and through
+ the <code>--driver-memory</code> command line option in the client mode.
<br /><em>Note:</em> In client mode, this config must not be set through the <code>SparkConf</code>
directly in your application, because the driver JVM has already started at that point.
Instead, please set this through the <code>--driver-java-options</code> command line option or in
- your default properties file.</td>
+ your default properties file.
</td>
</tr>
<tr>
@@ -269,9 +272,9 @@ Apart from these, the following properties are also available, and may be useful
<td>(none)</td>
<td>
A string of extra JVM options to pass to executors. For instance, GC settings or other logging.
- Note that it is illegal to set Spark properties or heap size settings with this option. Spark
- properties should be set using a SparkConf object or the spark-defaults.conf file used with the
- spark-submit script. Heap size settings can be set with spark.executor.memory.
+ Note that it is illegal to set Spark properties or maximum heap size (-Xmx) settings with this
+ option. Spark properties should be set using a SparkConf object or the spark-defaults.conf file
+ used with the spark-submit script. Maximum heap size settings can be set with spark.executor.memory.
</td>
</tr>
<tr>
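With -Xmx now disallowed in the extra-options strings, heap size and JVM flags end up in separate settings. A minimal sketch of that split for executors, assuming the conf is built before the application starts and using an illustrative GC flag:

    import org.apache.spark.SparkConf

    val conf = new SparkConf()
      .setAppName("jvm-options-demo")
      .set("spark.executor.memory", "4g")                      // heap size: dedicated setting, never -Xmx
      .set("spark.executor.extraJavaOptions", "-XX:+UseG1GC")  // GC / logging flags only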
diff --git a/docs/ml-classification-regression.md b/docs/ml-classification-regression.md
index 45155c8ad1..eaf4f6d843 100644
--- a/docs/ml-classification-regression.md
+++ b/docs/ml-classification-regression.md
@@ -302,6 +302,40 @@ Refer to the [Java API docs](api/java/org/apache/spark/ml/classification/OneVsRe
</div>
</div>
+## Naive Bayes
+
+[Naive Bayes](http://en.wikipedia.org/wiki/Naive_Bayes_classifier) is a family of simple
+probabilistic classifiers based on applying Bayes' theorem with strong (naive) independence
+assumptions between the features. The spark.ml implementation currently supports both [multinomial
+naive Bayes](http://nlp.stanford.edu/IR-book/html/htmledition/naive-bayes-text-classification-1.html)
+and [Bernoulli naive Bayes](http://nlp.stanford.edu/IR-book/html/htmledition/the-bernoulli-model-1.html).
+More information can be found in the section on [Naive Bayes in MLlib](mllib-naive-bayes.html#naive-bayes-sparkmllib).
+
+**Example**
+
+<div class="codetabs">
+<div data-lang="scala" markdown="1">
+
+Refer to the [Scala API docs](api/scala/index.html#org.apache.spark.ml.classification.NaiveBayes) for more details.
+
+{% include_example scala/org/apache/spark/examples/ml/NaiveBayesExample.scala %}
+</div>
+
+<div data-lang="java" markdown="1">
+
+Refer to the [Java API docs](api/java/org/apache/spark/ml/classification/NaiveBayes.html) for more details.
+
+{% include_example java/org/apache/spark/examples/ml/JavaNaiveBayesExample.java %}
+</div>
+
+<div data-lang="python" markdown="1">
+
+Refer to the [Python API docs](api/python/pyspark.ml.html#pyspark.ml.classification.NaiveBayes) for more details.
+
+{% include_example python/ml/naive_bayes_example.py %}
+</div>
+</div>
+
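For readers following along in the shell, the included examples boil down to the usual fit/transform flow. A minimal Scala sketch, assuming the shell's `sqlContext` and the sample libsvm file shipped with Spark:

    import org.apache.spark.ml.classification.NaiveBayes

    val data = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt")
    val Array(train, test) = data.randomSplit(Array(0.6, 0.4), seed = 1234L)

    // Fit a multinomial naive Bayes model and inspect a few predictions.
    val model = new NaiveBayes().fit(train)
    model.transform(test).select("label", "prediction").show(5)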
# Regression
diff --git a/docs/ml-features.md b/docs/ml-features.md
index 4fe8eefc26..70812eb5e2 100644
--- a/docs/ml-features.md
+++ b/docs/ml-features.md
@@ -149,6 +149,15 @@ for more details on the API.
{% include_example java/org/apache/spark/examples/ml/JavaCountVectorizerExample.java %}
</div>
+
+<div data-lang="python" markdown="1">
+
+Refer to the [CountVectorizer Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.CountVectorizer)
+and the [CountVectorizerModel Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.CountVectorizerModel)
+for more details on the API.
+
+{% include_example python/ml/count_vectorizer_example.py %}
+</div>
</div>
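For comparison with the new Python tab, the same CountVectorizer flow in Scala is roughly the following; a minimal sketch, assuming a toy DataFrame with an array-of-strings column named "words":

    import org.apache.spark.ml.feature.CountVectorizer

    val df = sqlContext.createDataFrame(Seq(
      (0, Array("a", "b", "c")),
      (1, Array("a", "b", "b", "c", "a"))
    )).toDF("id", "words")

    // Learn a vocabulary from the corpus, then map each document to a term-count vector.
    val cvModel = new CountVectorizer()
      .setInputCol("words")
      .setOutputCol("features")
      .setVocabSize(3)
      .setMinDF(2)
      .fit(df)
    cvModel.transform(df).select("features").show(false)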
# Feature Transformers
@@ -413,6 +422,14 @@ for more details on the API.
{% include_example java/org/apache/spark/examples/ml/JavaDCTExample.java %}
</div>
+
+<div data-lang="python" markdown="1">
+
+Refer to the [DCT Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.DCT)
+for more details on the API.
+
+{% include_example python/ml/dct_example.py %}
+</div>
</div>
## StringIndexer
@@ -771,6 +788,14 @@ for more details on the API.
{% include_example java/org/apache/spark/examples/ml/JavaMinMaxScalerExample.java %}
</div>
+
+<div data-lang="python" markdown="1">
+
+Refer to the [MinMaxScaler Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.MinMaxScaler)
+for more details on the API.
+
+{% include_example python/ml/min_max_scaler_example.py %}
+</div>
</div>
@@ -803,6 +828,14 @@ for more details on the API.
{% include_example java/org/apache/spark/examples/ml/JavaMaxAbsScalerExample.java %}
</div>
+
+<div data-lang="python" markdown="1">
+
+Refer to the [MaxAbsScaler Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.MaxAbsScaler)
+for more details on the API.
+
+{% include_example python/ml/max_abs_scaler_example.py %}
+</div>
</div>
## Bucketizer
diff --git a/docs/monitoring.md b/docs/monitoring.md
index c139e1cb5a..32d2e02e93 100644
--- a/docs/monitoring.md
+++ b/docs/monitoring.md
@@ -8,7 +8,7 @@ There are several ways to monitor Spark applications: web UIs, metrics, and exte
# Web Interfaces
-Every SparkContext launches a web UI, by default on port 4040, that
+Every SparkContext launches a web UI, by default on port 4040, that
displays useful information about the application. This includes:
* A list of scheduler stages and tasks
@@ -32,19 +32,19 @@ Spark's Standalone Mode cluster manager also has its own
the course of its lifetime, then the Standalone master's web UI will automatically re-render the
application's UI after the application has finished.
-If Spark is run on Mesos or YARN, it is still possible to reconstruct the UI of a finished
+If Spark is run on Mesos or YARN, it is still possible to construct the UI of an
application through Spark's history server, provided that the application's event logs exist.
You can start the history server by executing:
./sbin/start-history-server.sh
This creates a web interface at `http://<server-url>:18080` by default, listing incomplete
-and completed applications and attempts, and allowing them to be viewed
+and completed applications and attempts.
When using the file-system provider class (see `spark.history.provider` below), the base logging
directory must be supplied in the `spark.history.fs.logDirectory` configuration option,
and should contain sub-directories that each represents an application's event logs.
-
+
The spark jobs themselves must be configured to log events, and to log them to the same shared,
writeable directory. For example, if the server was configured with a log directory of
`hdfs://namenode/shared/spark-logs`, then the client-side options would be:
@@ -53,7 +53,7 @@ writeable directory. For example, if the server was configured with a log direct
spark.eventLog.enabled true
spark.eventLog.dir hdfs://namenode/shared/spark-logs
```
-
+
The history server can be configured as follows:
### Environment Variables
@@ -135,9 +135,9 @@ The history server can be configured as follows:
<td>false</td>
<td>
Indicates whether the history server should use kerberos to login. This is required
- if the history server is accessing HDFS files on a secure Hadoop cluster. If this is
+ if the history server is accessing HDFS files on a secure Hadoop cluster. If this is
true, it uses the configs <code>spark.history.kerberos.principal</code> and
- <code>spark.history.kerberos.keytab</code>.
+ <code>spark.history.kerberos.keytab</code>.
</td>
</tr>
<tr>
@@ -159,12 +159,12 @@ The history server can be configured as follows:
<td>false</td>
<td>
Specifies whether acls should be checked to authorize users viewing the applications.
- If enabled, access control checks are made regardless of what the individual application had
+ If enabled, access control checks are made regardless of what the individual application had
set for <code>spark.ui.acls.enable</code> when the application was run. The application owner
- will always have authorization to view their own application and any users specified via
+ will always have authorization to view their own application and any users specified via
<code>spark.ui.view.acls</code> when the application was run will also have authorization
- to view that application.
- If disabled, no access control checks are made.
+ to view that application.
+ If disabled, no access control checks are made.
</td>
</tr>
<tr>
@@ -298,14 +298,14 @@ keep the paths consistent in both modes.
# Metrics
-Spark has a configurable metrics system based on the
-[Coda Hale Metrics Library](http://metrics.codahale.com/).
-This allows users to report Spark metrics to a variety of sinks including HTTP, JMX, and CSV
-files. The metrics system is configured via a configuration file that Spark expects to be present
-at `$SPARK_HOME/conf/metrics.properties`. A custom file location can be specified via the
+Spark has a configurable metrics system based on the
+[Coda Hale Metrics Library](http://metrics.codahale.com/).
+This allows users to report Spark metrics to a variety of sinks including HTTP, JMX, and CSV
+files. The metrics system is configured via a configuration file that Spark expects to be present
+at `$SPARK_HOME/conf/metrics.properties`. A custom file location can be specified via the
`spark.metrics.conf` [configuration property](configuration.html#spark-properties).
-Spark's metrics are decoupled into different
-_instances_ corresponding to Spark components. Within each instance, you can configure a
+Spark's metrics are decoupled into different
+_instances_ corresponding to Spark components. Within each instance, you can configure a
set of sinks to which metrics are reported. The following instances are currently supported:
* `master`: The Spark standalone master process.
@@ -330,26 +330,26 @@ licensing restrictions:
* `GangliaSink`: Sends metrics to a Ganglia node or multicast group.
To install the `GangliaSink` you'll need to perform a custom build of Spark. _**Note that
-by embedding this library you will include [LGPL](http://www.gnu.org/copyleft/lesser.html)-licensed
-code in your Spark package**_. For sbt users, set the
-`SPARK_GANGLIA_LGPL` environment variable before building. For Maven users, enable
+by embedding this library you will include [LGPL](http://www.gnu.org/copyleft/lesser.html)-licensed
+code in your Spark package**_. For sbt users, set the
+`SPARK_GANGLIA_LGPL` environment variable before building. For Maven users, enable
the `-Pspark-ganglia-lgpl` profile. In addition to modifying the cluster's Spark build
user applications will need to link to the `spark-ganglia-lgpl` artifact.
-The syntax of the metrics configuration file is defined in an example configuration file,
+The syntax of the metrics configuration file is defined in an example configuration file,
`$SPARK_HOME/conf/metrics.properties.template`.
# Advanced Instrumentation
Several external tools can be used to help profile the performance of Spark jobs:
-* Cluster-wide monitoring tools, such as [Ganglia](http://ganglia.sourceforge.net/), can provide
-insight into overall cluster utilization and resource bottlenecks. For instance, a Ganglia
-dashboard can quickly reveal whether a particular workload is disk bound, network bound, or
+* Cluster-wide monitoring tools, such as [Ganglia](http://ganglia.sourceforge.net/), can provide
+insight into overall cluster utilization and resource bottlenecks. For instance, a Ganglia
+dashboard can quickly reveal whether a particular workload is disk bound, network bound, or
CPU bound.
-* OS profiling tools such as [dstat](http://dag.wieers.com/home-made/dstat/),
-[iostat](http://linux.die.net/man/1/iostat), and [iotop](http://linux.die.net/man/1/iotop)
+* OS profiling tools such as [dstat](http://dag.wieers.com/home-made/dstat/),
+[iostat](http://linux.die.net/man/1/iostat), and [iotop](http://linux.die.net/man/1/iotop)
can provide fine-grained profiling on individual nodes.
-* JVM utilities such as `jstack` for providing stack traces, `jmap` for creating heap-dumps,
-`jstat` for reporting time-series statistics and `jconsole` for visually exploring various JVM
+* JVM utilities such as `jstack` for providing stack traces, `jmap` for creating heap-dumps,
+`jstat` for reporting time-series statistics and `jconsole` for visually exploring various JVM
properties are useful for those comfortable with JVM internals.
diff --git a/docs/running-on-mesos.md b/docs/running-on-mesos.md
index 293a82882e..4a0ab623c1 100644
--- a/docs/running-on-mesos.md
+++ b/docs/running-on-mesos.md
@@ -108,7 +108,7 @@ the `dev/make-distribution.sh` script included in a Spark source tarball/checkou
## Using a Mesos Master URL
The Master URLs for Mesos are in the form `mesos://host:5050` for a single-master Mesos
-cluster, or `mesos://zk://host:2181` for a multi-master Mesos cluster using ZooKeeper.
+cluster, or `mesos://zk://host1:2181,host2:2181,host3:2181/mesos` for a multi-master Mesos cluster using ZooKeeper.
## Client Mode
@@ -215,10 +215,10 @@ conf.set("spark.mesos.coarse", "false")
You may also make use of `spark.mesos.constraints` to set attribute based constraints on mesos resource offers. By default, all resource offers will be accepted.
{% highlight scala %}
-conf.set("spark.mesos.constraints", "tachyon:true;us-east-1:false")
+conf.set("spark.mesos.constraints", "os:centos7;us-east-1:false")
{% endhighlight %}
-For example, Let's say `spark.mesos.constraints` is set to `tachyon:true;us-east-1:false`, then the resource offers will be checked to see if they meet both these constraints and only then will be accepted to start new executors.
+For example, if `spark.mesos.constraints` is set to `os:centos7;us-east-1:false`, the resource offers will be checked to see whether they meet both constraints, and only then will they be accepted to start new executors.
# Mesos Docker Support
diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md
index c775fe710f..09701abdb0 100644
--- a/docs/running-on-yarn.md
+++ b/docs/running-on-yarn.md
@@ -160,6 +160,13 @@ If you need a reference to the proper location to put log files in the YARN so t
</td>
</tr>
<tr>
+ <td><code>spark.yarn.stagingDir</code></td>
+ <td>Current user's home directory in the filesystem</td>
+ <td>
+ Staging directory used while submitting applications.
+ </td>
+</tr>
+<tr>
<td><code>spark.yarn.preserve.staging.files</code></td>
<td><code>false</code></td>
<td>
@@ -216,6 +223,13 @@ If you need a reference to the proper location to put log files in the YARN so t
</td>
</tr>
<tr>
+ <td><code>spark.yarn.dist.jars</code></td>
+ <td>(none)</td>
+ <td>
+ Comma-separated list of jars to be placed in the working directory of each executor.
+ </td>
+</tr>
+<tr>
<td><code>spark.executor.cores</code></td>
<td>1 in YARN mode, all the available cores on the worker in standalone mode.</td>
<td>
@@ -328,7 +342,9 @@ If you need a reference to the proper location to put log files in the YARN so t
<td>(none)</td>
<td>
A string of extra JVM options to pass to the YARN Application Master in client mode.
- In cluster mode, use <code>spark.driver.extraJavaOptions</code> instead.
+ In cluster mode, use <code>spark.driver.extraJavaOptions</code> instead. Note that it is illegal
+ to set maximum heap size (-Xmx) settings with this option. Maximum heap size settings can be set
+ with <code>spark.yarn.am.memory</code>.
</td>
</tr>
<tr>
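A minimal sketch of wiring the two new YARN properties into an application, assuming client mode (so the values are read when the SparkContext, and hence the YARN client, starts) and illustrative HDFS paths:

    import org.apache.spark.{SparkConf, SparkContext}

    val conf = new SparkConf()
      .setAppName("yarn-staging-demo")
      .set("spark.yarn.stagingDir", "hdfs:///user/alice/spark-staging")                 // illustrative path
      .set("spark.yarn.dist.jars", "hdfs:///libs/extra-1.jar,hdfs:///libs/extra-2.jar") // shipped to executors
    val sc = new SparkContext(conf)   // master ("yarn") typically supplied by spark-submit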
diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md
index 2fdc97f8a0..2d9849d032 100644
--- a/docs/sql-programming-guide.md
+++ b/docs/sql-programming-guide.md
@@ -1467,37 +1467,6 @@ Configuration of Parquet can be done using the `setConf` method on `SQLContext`
</td>
</tr>
<tr>
- <td><code>spark.sql.parquet.output.committer.class</code></td>
- <td><code>org.apache.parquet.hadoop.<br />ParquetOutputCommitter</code></td>
- <td>
- <p>
- The output committer class used by Parquet. The specified class needs to be a subclass of
- <code>org.apache.hadoop.<br />mapreduce.OutputCommitter</code>. Typically, it's also a
- subclass of <code>org.apache.parquet.hadoop.ParquetOutputCommitter</code>.
- </p>
- <p>
- <b>Note:</b>
- <ul>
- <li>
- This option is automatically ignored if <code>spark.speculation</code> is turned on.
- </li>
- <li>
- This option must be set via Hadoop <code>Configuration</code> rather than Spark
- <code>SQLConf</code>.
- </li>
- <li>
- This option overrides <code>spark.sql.sources.<br />outputCommitterClass</code>.
- </li>
- </ul>
- </p>
- <p>
- Spark SQL comes with a builtin
- <code>org.apache.spark.sql.<br />parquet.DirectParquetOutputCommitter</code>, which can be more
- efficient then the default Parquet output committer when writing data to S3.
- </p>
- </td>
-</tr>
-<tr>
<td><code>spark.sql.parquet.mergeSchema</code></td>
<td><code>false</code></td>
<td>
@@ -1533,7 +1502,7 @@ val people = sqlContext.read.json(path)
// The inferred schema can be visualized using the printSchema() method.
people.printSchema()
// root
-// |-- age: integer (nullable = true)
+// |-- age: long (nullable = true)
// |-- name: string (nullable = true)
// Register this DataFrame as a table.
@@ -1571,7 +1540,7 @@ DataFrame people = sqlContext.read().json("examples/src/main/resources/people.js
// The inferred schema can be visualized using the printSchema() method.
people.printSchema();
// root
-// |-- age: integer (nullable = true)
+// |-- age: long (nullable = true)
// |-- name: string (nullable = true)
// Register this DataFrame as a table.
@@ -1609,7 +1578,7 @@ people = sqlContext.read.json("examples/src/main/resources/people.json")
# The inferred schema can be visualized using the printSchema() method.
people.printSchema()
# root
-# |-- age: integer (nullable = true)
+# |-- age: long (nullable = true)
# |-- name: string (nullable = true)
# Register this DataFrame as a table.
@@ -1648,7 +1617,7 @@ people <- jsonFile(sqlContext, path)
# The inferred schema can be visualized using the printSchema() method.
printSchema(people)
# root
-# |-- age: integer (nullable = true)
+# |-- age: long (nullable = true)
# |-- name: string (nullable = true)
# Register this DataFrame as a table.
@@ -1687,12 +1656,7 @@ on all of the worker nodes, as they will need access to the Hive serialization a
(SerDes) in order to access data stored in Hive.
Configuration of Hive is done by placing your `hive-site.xml`, `core-site.xml` (for security configuration),
- `hdfs-site.xml` (for HDFS configuration) file in `conf/`. Please note when running
-the query on a YARN cluster (`cluster` mode), the `datanucleus` jars under the `lib` directory
-and `hive-site.xml` under `conf/` directory need to be available on the driver and all executors launched by the
-YARN cluster. The convenient way to do this is adding them through the `--jars` option and `--file` option of the
-`spark-submit` command.
-
+`hdfs-site.xml` (for HDFS configuration) file in `conf/`.
<div class="codetabs">
@@ -2170,8 +2134,6 @@ options.
- In the `sql` dialect, floating point numbers are now parsed as decimal. HiveQL parsing remains
unchanged.
- The canonical name of SQL/DataFrame functions are now lower case (e.g. sum vs SUM).
- - It has been determined that using the DirectOutputCommitter when speculation is enabled is unsafe
- and thus this output committer will not be used when speculation is on, independent of configuration.
- JSON data source will not automatically load new files that are created by other applications
(i.e. files that are not inserted to the dataset through Spark SQL).
For a JSON persistent table (i.e. the metadata of the table is stored in Hive Metastore),
diff --git a/docs/streaming-programming-guide.md b/docs/streaming-programming-guide.md
index 8d21917a7d..7f6c0ed699 100644
--- a/docs/streaming-programming-guide.md
+++ b/docs/streaming-programming-guide.md
@@ -2178,7 +2178,7 @@ overall processing throughput of the system, its use is still recommended to ach
consistent batch processing times. Make sure you set the CMS GC on both the driver (using `--driver-java-options` in `spark-submit`) and the executors (using [Spark configuration](configuration.html#runtime-environment) `spark.executor.extraJavaOptions`).
* **Other tips**: To further reduce GC overheads, here are some more tips to try.
- - Use Tachyon for off-heap storage of persisted RDDs. See more detail in the [Spark Programming Guide](programming-guide.html#rdd-persistence).
+ - Persist RDDs using the `OFF_HEAP` storage level. See more detail in the [Spark Programming Guide](programming-guide.html#rdd-persistence).
- Use more executors with smaller heap sizes. This will reduce the GC pressure within each JVM heap.
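In code, the updated tip amounts to choosing the OFF_HEAP storage level when persisting; a minimal sketch, assuming an existing RDD named `rdd` (off-heap memory typically also needs to be enabled and sized via the spark.memory.offHeap.* settings):

    import org.apache.spark.storage.StorageLevel

    // Keep the persisted blocks off the JVM heap to reduce GC pressure.
    rdd.persist(StorageLevel.OFF_HEAP)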
diff --git a/docs/submitting-applications.md b/docs/submitting-applications.md
index 66025ed6ba..100ff0b147 100644
--- a/docs/submitting-applications.md
+++ b/docs/submitting-applications.md
@@ -58,8 +58,7 @@ for applications that involve the REPL (e.g. Spark shell).
Alternatively, if your application is submitted from a machine far from the worker machines (e.g.
locally on your laptop), it is common to use `cluster` mode to minimize network latency between
-the drivers and the executors. Note that `cluster` mode is currently not supported for
-Mesos clusters. Currently only YARN supports cluster mode for Python applications.
+the drivers and the executors. Currently only YARN supports cluster mode for Python applications.
For Python applications, simply pass a `.py` file in the place of `<application-jar>` instead of a JAR,
and add Python `.zip`, `.egg` or `.py` files to the search path with `--py-files`.
diff --git a/examples/pom.xml b/examples/pom.xml
index b7f37978b9..4a20370f06 100644
--- a/examples/pom.xml
+++ b/examples/pom.xml
@@ -27,13 +27,16 @@
<groupId>org.apache.spark</groupId>
<artifactId>spark-examples_2.11</artifactId>
- <properties>
- <sbt.project.name>examples</sbt.project.name>
- </properties>
<packaging>jar</packaging>
<name>Spark Project Examples</name>
<url>http://spark.apache.org/</url>
+ <properties>
+ <sbt.project.name>examples</sbt.project.name>
+ <build.testJarPhase>none</build.testJarPhase>
+ <build.copyDependenciesPhase>package</build.copyDependenciesPhase>
+ </properties>
+
<dependencies>
<dependency>
<groupId>org.apache.spark</groupId>
@@ -77,23 +80,6 @@
</dependency>
<dependency>
<groupId>org.apache.hbase</groupId>
- <artifactId>hbase-testing-util</artifactId>
- <version>${hbase.version}</version>
- <scope>${hbase.deps.scope}</scope>
- <exclusions>
- <exclusion>
- <!-- SPARK-4455 -->
- <groupId>org.apache.hbase</groupId>
- <artifactId>hbase-annotations</artifactId>
- </exclusion>
- <exclusion>
- <groupId>org.jruby</groupId>
- <artifactId>jruby-complete</artifactId>
- </exclusion>
- </exclusions>
- </dependency>
- <dependency>
- <groupId>org.apache.hbase</groupId>
<artifactId>hbase-protocol</artifactId>
<version>${hbase.version}</version>
<scope>${hbase.deps.scope}</scope>
@@ -140,6 +126,10 @@
<artifactId>hbase-annotations</artifactId>
</exclusion>
<exclusion>
+ <groupId>org.apache.hbase</groupId>
+ <artifactId>hbase-common</artifactId>
+ </exclusion>
+ <exclusion>
<groupId>org.apache.hadoop</groupId>
<artifactId>hadoop-core</artifactId>
</exclusion>
@@ -209,13 +199,6 @@
<scope>${hbase.deps.scope}</scope>
</dependency>
<dependency>
- <groupId>org.apache.hbase</groupId>
- <artifactId>hbase-hadoop-compat</artifactId>
- <version>${hbase.version}</version>
- <type>test-jar</type>
- <scope>test</scope>
- </dependency>
- <dependency>
<groupId>org.apache.commons</groupId>
<artifactId>commons-math3</artifactId>
<scope>provided</scope>
@@ -294,17 +277,6 @@
<artifactId>scopt_${scala.binary.version}</artifactId>
<version>3.3.0</version>
</dependency>
-
- <!--
- The following dependencies are already present in the Spark assembly, so we want to force
- them to be provided.
- -->
- <dependency>
- <groupId>org.scala-lang</groupId>
- <artifactId>scala-library</artifactId>
- <scope>provided</scope>
- </dependency>
-
</dependencies>
<build>
@@ -325,38 +297,6 @@
<skip>true</skip>
</configuration>
</plugin>
- <plugin>
- <groupId>org.apache.maven.plugins</groupId>
- <artifactId>maven-jar-plugin</artifactId>
- <executions>
- <execution>
- <id>prepare-test-jar</id>
- <phase>none</phase>
- <goals>
- <goal>test-jar</goal>
- </goals>
- </execution>
- </executions>
- <configuration>
- <outputDirectory>${jars.target.dir}</outputDirectory>
- </configuration>
- </plugin>
- <plugin>
- <groupId>org.apache.maven.plugins</groupId>
- <artifactId>maven-dependency-plugin</artifactId>
- <executions>
- <execution>
- <phase>package</phase>
- <goals>
- <goal>copy-dependencies</goal>
- </goals>
- <configuration>
- <includeScope>runtime</includeScope>
- <outputDirectory>${jars.target.dir}</outputDirectory>
- </configuration>
- </execution>
- </executions>
- </plugin>
</plugins>
</build>
<profiles>
diff --git a/examples/src/main/java/org/apache/spark/examples/JavaLogQuery.java b/examples/src/main/java/org/apache/spark/examples/JavaLogQuery.java
index 8abc03e73d..ebb0687b14 100644
--- a/examples/src/main/java/org/apache/spark/examples/JavaLogQuery.java
+++ b/examples/src/main/java/org/apache/spark/examples/JavaLogQuery.java
@@ -82,10 +82,10 @@ public final class JavaLogQuery {
String user = m.group(3);
String query = m.group(5);
if (!user.equalsIgnoreCase("-")) {
- return new Tuple3<String, String, String>(ip, user, query);
+ return new Tuple3<>(ip, user, query);
}
}
- return new Tuple3<String, String, String>(null, null, null);
+ return new Tuple3<>(null, null, null);
}
public static Stats extractStats(String line) {
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaCrossValidatorExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaCrossValidatorExample.java
deleted file mode 100644
index 07edeb3e52..0000000000
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaCrossValidatorExample.java
+++ /dev/null
@@ -1,127 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.examples.ml;
-
-import java.util.List;
-
-import com.google.common.collect.Lists;
-
-import org.apache.spark.SparkConf;
-import org.apache.spark.api.java.JavaSparkContext;
-import org.apache.spark.ml.Pipeline;
-import org.apache.spark.ml.PipelineStage;
-import org.apache.spark.ml.classification.LogisticRegression;
-import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator;
-import org.apache.spark.ml.feature.HashingTF;
-import org.apache.spark.ml.feature.Tokenizer;
-import org.apache.spark.ml.param.ParamMap;
-import org.apache.spark.ml.tuning.CrossValidator;
-import org.apache.spark.ml.tuning.CrossValidatorModel;
-import org.apache.spark.ml.tuning.ParamGridBuilder;
-import org.apache.spark.sql.Dataset;
-import org.apache.spark.sql.Row;
-import org.apache.spark.sql.SQLContext;
-
-/**
- * A simple example demonstrating model selection using CrossValidator.
- * This example also demonstrates how Pipelines are Estimators.
- *
- * This example uses the Java bean classes {@link org.apache.spark.examples.ml.LabeledDocument} and
- * {@link org.apache.spark.examples.ml.Document} defined in the Scala example
- * {@link org.apache.spark.examples.ml.SimpleTextClassificationPipeline}.
- *
- * Run with
- * <pre>
- * bin/run-example ml.JavaCrossValidatorExample
- * </pre>
- */
-public class JavaCrossValidatorExample {
-
- public static void main(String[] args) {
- SparkConf conf = new SparkConf().setAppName("JavaCrossValidatorExample");
- JavaSparkContext jsc = new JavaSparkContext(conf);
- SQLContext jsql = new SQLContext(jsc);
-
- // Prepare training documents, which are labeled.
- List<LabeledDocument> localTraining = Lists.newArrayList(
- new LabeledDocument(0L, "a b c d e spark", 1.0),
- new LabeledDocument(1L, "b d", 0.0),
- new LabeledDocument(2L, "spark f g h", 1.0),
- new LabeledDocument(3L, "hadoop mapreduce", 0.0),
- new LabeledDocument(4L, "b spark who", 1.0),
- new LabeledDocument(5L, "g d a y", 0.0),
- new LabeledDocument(6L, "spark fly", 1.0),
- new LabeledDocument(7L, "was mapreduce", 0.0),
- new LabeledDocument(8L, "e spark program", 1.0),
- new LabeledDocument(9L, "a e c l", 0.0),
- new LabeledDocument(10L, "spark compile", 1.0),
- new LabeledDocument(11L, "hadoop software", 0.0));
- Dataset<Row> training = jsql.createDataFrame(
- jsc.parallelize(localTraining), LabeledDocument.class);
-
- // Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr.
- Tokenizer tokenizer = new Tokenizer()
- .setInputCol("text")
- .setOutputCol("words");
- HashingTF hashingTF = new HashingTF()
- .setNumFeatures(1000)
- .setInputCol(tokenizer.getOutputCol())
- .setOutputCol("features");
- LogisticRegression lr = new LogisticRegression()
- .setMaxIter(10)
- .setRegParam(0.01);
- Pipeline pipeline = new Pipeline()
- .setStages(new PipelineStage[] {tokenizer, hashingTF, lr});
-
- // We now treat the Pipeline as an Estimator, wrapping it in a CrossValidator instance.
- // This will allow us to jointly choose parameters for all Pipeline stages.
- // A CrossValidator requires an Estimator, a set of Estimator ParamMaps, and an Evaluator.
- CrossValidator crossval = new CrossValidator()
- .setEstimator(pipeline)
- .setEvaluator(new BinaryClassificationEvaluator());
- // We use a ParamGridBuilder to construct a grid of parameters to search over.
- // With 3 values for hashingTF.numFeatures and 2 values for lr.regParam,
- // this grid will have 3 x 2 = 6 parameter settings for CrossValidator to choose from.
- ParamMap[] paramGrid = new ParamGridBuilder()
- .addGrid(hashingTF.numFeatures(), new int[]{10, 100, 1000})
- .addGrid(lr.regParam(), new double[]{0.1, 0.01})
- .build();
- crossval.setEstimatorParamMaps(paramGrid);
- crossval.setNumFolds(2); // Use 3+ in practice
-
- // Run cross-validation, and choose the best set of parameters.
- CrossValidatorModel cvModel = crossval.fit(training);
-
- // Prepare test documents, which are unlabeled.
- List<Document> localTest = Lists.newArrayList(
- new Document(4L, "spark i j k"),
- new Document(5L, "l m n"),
- new Document(6L, "mapreduce spark"),
- new Document(7L, "apache hadoop"));
- Dataset<Row> test = jsql.createDataFrame(jsc.parallelize(localTest), Document.class);
-
- // Make predictions on test documents. cvModel uses the best model found (lrModel).
- Dataset<Row> predictions = cvModel.transform(test);
- for (Row r: predictions.select("id", "text", "probability", "prediction").collectAsList()) {
- System.out.println("(" + r.get(0) + ", " + r.get(1) + ") --> prob=" + r.get(2)
- + ", prediction=" + r.get(3));
- }
-
- jsc.stop();
- }
-}
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java
index fbd8817669..0ba94786d4 100644
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java
@@ -146,7 +146,7 @@ class MyJavaLogisticRegression
// This method is used by fit().
// In Java, we have to make it public since Java does not understand Scala's protected modifier.
- public MyJavaLogisticRegressionModel train(Dataset<Row> dataset) {
+ public MyJavaLogisticRegressionModel train(Dataset<?> dataset) {
// Extract columns from data using helper method.
JavaRDD<LabeledPoint> oldDataset = extractLabeledPoints(dataset).toJavaRDD();
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaModelSelectionViaTrainValidationSplitExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaModelSelectionViaTrainValidationSplitExample.java
index 6ac4aea3c4..4994f8f9fa 100644
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaModelSelectionViaTrainValidationSplitExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaModelSelectionViaTrainValidationSplitExample.java
@@ -32,7 +32,15 @@ import org.apache.spark.sql.Row;
import org.apache.spark.sql.SQLContext;
/**
- * Java example for Model Selection via Train Validation Split.
+ * Java example demonstrating model selection using TrainValidationSplit.
+ *
+ * The example is based on {@link org.apache.spark.examples.ml.JavaSimpleParamsExample}
+ * using linear regression.
+ *
+ * Run with
+ * {{{
+ * bin/run-example ml.JavaModelSelectionViaTrainValidationSplitExample
+ * }}}
*/
public class JavaModelSelectionViaTrainValidationSplitExample {
public static void main(String[] args) {
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaNaiveBayesExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaNaiveBayesExample.java
new file mode 100644
index 0000000000..41d7ad75b9
--- /dev/null
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaNaiveBayesExample.java
@@ -0,0 +1,64 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.examples.ml;
+
+
+import org.apache.spark.SparkConf;
+import org.apache.spark.api.java.JavaSparkContext;
+// $example on$
+import org.apache.spark.ml.classification.NaiveBayes;
+import org.apache.spark.ml.classification.NaiveBayesModel;
+import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.SQLContext;
+// $example off$
+
+/**
+ * An example for Naive Bayes Classification.
+ */
+public class JavaNaiveBayesExample {
+
+ public static void main(String[] args) {
+ SparkConf conf = new SparkConf().setAppName("JavaNaiveBayesExample");
+ JavaSparkContext jsc = new JavaSparkContext(conf);
+ SQLContext jsql = new SQLContext(jsc);
+
+ // $example on$
+ // Load training data
+ Dataset<Row> dataFrame = jsql.read().format("libsvm").load("data/mllib/sample_libsvm_data.txt");
+ // Split the data into train and test
+ Dataset<Row>[] splits = dataFrame.randomSplit(new double[]{0.6, 0.4}, 1234L);
+ Dataset<Row> train = splits[0];
+ Dataset<Row> test = splits[1];
+
+ // create the trainer and set its parameters
+ NaiveBayes nb = new NaiveBayes();
+ // train the model
+ NaiveBayesModel model = nb.fit(train);
+ // compute precision on the test set
+ Dataset<Row> result = model.transform(test);
+ Dataset<Row> predictionAndLabels = result.select("prediction", "label");
+ MulticlassClassificationEvaluator evaluator = new MulticlassClassificationEvaluator()
+ .setMetricName("precision");
+ System.out.println("Precision = " + evaluator.evaluate(predictionAndLabels));
+ // $example off$
+
+ jsc.stop();
+ }
+}
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaTrainValidationSplitExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaTrainValidationSplitExample.java
deleted file mode 100644
index 09bbc39c01..0000000000
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaTrainValidationSplitExample.java
+++ /dev/null
@@ -1,87 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.examples.ml;
-
-import org.apache.spark.SparkConf;
-import org.apache.spark.api.java.JavaSparkContext;
-import org.apache.spark.ml.evaluation.RegressionEvaluator;
-import org.apache.spark.ml.param.ParamMap;
-import org.apache.spark.ml.regression.LinearRegression;
-import org.apache.spark.ml.tuning.*;
-import org.apache.spark.sql.Dataset;
-import org.apache.spark.sql.Row;
-import org.apache.spark.sql.SQLContext;
-
-/**
- * A simple example demonstrating model selection using TrainValidationSplit.
- *
- * The example is based on {@link org.apache.spark.examples.ml.JavaSimpleParamsExample}
- * using linear regression.
- *
- * Run with
- * {{{
- * bin/run-example ml.JavaTrainValidationSplitExample
- * }}}
- */
-public class JavaTrainValidationSplitExample {
-
- public static void main(String[] args) {
- SparkConf conf = new SparkConf().setAppName("JavaTrainValidationSplitExample");
- JavaSparkContext jsc = new JavaSparkContext(conf);
- SQLContext jsql = new SQLContext(jsc);
-
- Dataset<Row> data = jsql.read().format("libsvm").load("data/mllib/sample_libsvm_data.txt");
-
- // Prepare training and test data.
- Dataset<Row>[] splits = data.randomSplit(new double [] {0.9, 0.1}, 12345);
- Dataset<Row> training = splits[0];
- Dataset<Row> test = splits[1];
-
- LinearRegression lr = new LinearRegression();
-
- // We use a ParamGridBuilder to construct a grid of parameters to search over.
- // TrainValidationSplit will try all combinations of values and determine best model using
- // the evaluator.
- ParamMap[] paramGrid = new ParamGridBuilder()
- .addGrid(lr.regParam(), new double[] {0.1, 0.01})
- .addGrid(lr.fitIntercept())
- .addGrid(lr.elasticNetParam(), new double[] {0.0, 0.5, 1.0})
- .build();
-
- // In this case the estimator is simply the linear regression.
- // A TrainValidationSplit requires an Estimator, a set of Estimator ParamMaps, and an Evaluator.
- TrainValidationSplit trainValidationSplit = new TrainValidationSplit()
- .setEstimator(lr)
- .setEvaluator(new RegressionEvaluator())
- .setEstimatorParamMaps(paramGrid);
-
- // 80% of the data will be used for training and the remaining 20% for validation.
- trainValidationSplit.setTrainRatio(0.8);
-
- // Run train validation split, and choose the best set of parameters.
- TrainValidationSplitModel model = trainValidationSplit.fit(training);
-
- // Make predictions on test data. model is the model with combination of parameters
- // that performed best.
- model.transform(test)
- .select("features", "label", "prediction")
- .show();
-
- jsc.stop();
- }
-}
diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaFPGrowthExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaFPGrowthExample.java
deleted file mode 100644
index 36baf58687..0000000000
--- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaFPGrowthExample.java
+++ /dev/null
@@ -1,78 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.examples.mllib;
-
-import java.util.ArrayList;
-
-import com.google.common.base.Joiner;
-import com.google.common.collect.Lists;
-
-import org.apache.spark.SparkConf;
-import org.apache.spark.api.java.JavaRDD;
-import org.apache.spark.api.java.JavaSparkContext;
-import org.apache.spark.api.java.function.Function;
-import org.apache.spark.mllib.fpm.FPGrowth;
-import org.apache.spark.mllib.fpm.FPGrowthModel;
-
-/**
- * Java example for mining frequent itemsets using FP-growth.
- * Example usage: ./bin/run-example mllib.JavaFPGrowthExample ./data/mllib/sample_fpgrowth.txt
- */
-public class JavaFPGrowthExample {
-
- public static void main(String[] args) {
- String inputFile;
- double minSupport = 0.3;
- int numPartition = -1;
- if (args.length < 1) {
- System.err.println(
- "Usage: JavaFPGrowth <input_file> [minSupport] [numPartition]");
- System.exit(1);
- }
- inputFile = args[0];
- if (args.length >= 2) {
- minSupport = Double.parseDouble(args[1]);
- }
- if (args.length >= 3) {
- numPartition = Integer.parseInt(args[2]);
- }
-
- SparkConf sparkConf = new SparkConf().setAppName("JavaFPGrowthExample");
- JavaSparkContext sc = new JavaSparkContext(sparkConf);
-
- JavaRDD<ArrayList<String>> transactions = sc.textFile(inputFile).map(
- new Function<String, ArrayList<String>>() {
- @Override
- public ArrayList<String> call(String s) {
- return Lists.newArrayList(s.split(" "));
- }
- }
- );
-
- FPGrowthModel<String> model = new FPGrowth()
- .setMinSupport(minSupport)
- .setNumPartitions(numPartition)
- .run(transactions);
-
- for (FPGrowth.FreqItemset<String> s: model.freqItemsets().toJavaRDD().collect()) {
- System.out.println("[" + Joiner.on(",").join(s.javaItems()) + "], " + s.freq());
- }
-
- sc.stop();
- }
-}
diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaKMeans.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaKMeans.java
deleted file mode 100644
index e575eedeb4..0000000000
--- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaKMeans.java
+++ /dev/null
@@ -1,82 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.examples.mllib;
-
-import java.util.regex.Pattern;
-
-import org.apache.spark.SparkConf;
-import org.apache.spark.api.java.JavaRDD;
-import org.apache.spark.api.java.JavaSparkContext;
-import org.apache.spark.api.java.function.Function;
-
-import org.apache.spark.mllib.clustering.KMeans;
-import org.apache.spark.mllib.clustering.KMeansModel;
-import org.apache.spark.mllib.linalg.Vector;
-import org.apache.spark.mllib.linalg.Vectors;
-
-/**
- * Example using MLlib KMeans from Java.
- */
-public final class JavaKMeans {
-
- private static class ParsePoint implements Function<String, Vector> {
- private static final Pattern SPACE = Pattern.compile(" ");
-
- @Override
- public Vector call(String line) {
- String[] tok = SPACE.split(line);
- double[] point = new double[tok.length];
- for (int i = 0; i < tok.length; ++i) {
- point[i] = Double.parseDouble(tok[i]);
- }
- return Vectors.dense(point);
- }
- }
-
- public static void main(String[] args) {
- if (args.length < 3) {
- System.err.println(
- "Usage: JavaKMeans <input_file> <k> <max_iterations> [<runs>]");
- System.exit(1);
- }
- String inputFile = args[0];
- int k = Integer.parseInt(args[1]);
- int iterations = Integer.parseInt(args[2]);
- int runs = 1;
-
- if (args.length >= 4) {
- runs = Integer.parseInt(args[3]);
- }
- SparkConf sparkConf = new SparkConf().setAppName("JavaKMeans");
- JavaSparkContext sc = new JavaSparkContext(sparkConf);
- JavaRDD<String> lines = sc.textFile(inputFile);
-
- JavaRDD<Vector> points = lines.map(new ParsePoint());
-
- KMeansModel model = KMeans.train(points.rdd(), k, iterations, runs, KMeans.K_MEANS_PARALLEL());
-
- System.out.println("Cluster centers:");
- for (Vector center : model.clusterCenters()) {
- System.out.println(" " + center);
- }
- double cost = model.computeCost(points.rdd());
- System.out.println("Cost: " + cost);
-
- sc.stop();
- }
-}
diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaKMeansExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaKMeansExample.java
index 006d96d111..2d89c768fc 100644
--- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaKMeansExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaKMeansExample.java
@@ -58,6 +58,13 @@ public class JavaKMeansExample {
int numIterations = 20;
KMeansModel clusters = KMeans.train(parsedData.rdd(), numClusters, numIterations);
+ System.out.println("Cluster centers:");
+ for (Vector center: clusters.clusterCenters()) {
+ System.out.println(" " + center);
+ }
+ double cost = clusters.computeCost(parsedData.rdd());
+ System.out.println("Cost: " + cost);
+
// Evaluate clustering by computing Within Set Sum of Squared Errors
double WSSSE = clusters.computeCost(parsedData.rdd());
System.out.println("Within Set Sum of Squared Errors = " + WSSSE);
diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaLDAExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaLDAExample.java
deleted file mode 100644
index de8e739ac9..0000000000
--- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaLDAExample.java
+++ /dev/null
@@ -1,77 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.examples.mllib;
-
-import scala.Tuple2;
-
-import org.apache.spark.api.java.*;
-import org.apache.spark.api.java.function.Function;
-import org.apache.spark.mllib.clustering.DistributedLDAModel;
-import org.apache.spark.mllib.clustering.LDA;
-import org.apache.spark.mllib.linalg.Matrix;
-import org.apache.spark.mllib.linalg.Vector;
-import org.apache.spark.mllib.linalg.Vectors;
-import org.apache.spark.SparkConf;
-
-public class JavaLDAExample {
- public static void main(String[] args) {
- SparkConf conf = new SparkConf().setAppName("LDA Example");
- JavaSparkContext sc = new JavaSparkContext(conf);
-
- // Load and parse the data
- String path = "data/mllib/sample_lda_data.txt";
- JavaRDD<String> data = sc.textFile(path);
- JavaRDD<Vector> parsedData = data.map(
- new Function<String, Vector>() {
- public Vector call(String s) {
- String[] sarray = s.trim().split(" ");
- double[] values = new double[sarray.length];
- for (int i = 0; i < sarray.length; i++) {
- values[i] = Double.parseDouble(sarray[i]);
- }
- return Vectors.dense(values);
- }
- }
- );
- // Index documents with unique IDs
- JavaPairRDD<Long, Vector> corpus = JavaPairRDD.fromJavaRDD(parsedData.zipWithIndex().map(
- new Function<Tuple2<Vector, Long>, Tuple2<Long, Vector>>() {
- public Tuple2<Long, Vector> call(Tuple2<Vector, Long> doc_id) {
- return doc_id.swap();
- }
- }
- ));
- corpus.cache();
-
- // Cluster the documents into three topics using LDA
- DistributedLDAModel ldaModel = (DistributedLDAModel)new LDA().setK(3).run(corpus);
-
- // Output topics. Each is a distribution over words (matching word count vectors)
- System.out.println("Learned topics (as distributions over vocab of " + ldaModel.vocabSize()
- + " words):");
- Matrix topics = ldaModel.topicsMatrix();
- for (int topic = 0; topic < 3; topic++) {
- System.out.print("Topic " + topic + ":");
- for (int word = 0; word < ldaModel.vocabSize(); word++) {
- System.out.print(" " + topics.apply(word, topic));
- }
- System.out.println();
- }
- sc.stop();
- }
-}
diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaLR.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaLR.java
deleted file mode 100644
index eceb6927d5..0000000000
--- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaLR.java
+++ /dev/null
@@ -1,82 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.examples.mllib;
-
-import java.util.regex.Pattern;
-
-import org.apache.spark.SparkConf;
-import org.apache.spark.api.java.JavaRDD;
-import org.apache.spark.api.java.JavaSparkContext;
-import org.apache.spark.api.java.function.Function;
-
-import org.apache.spark.mllib.classification.LogisticRegressionWithSGD;
-import org.apache.spark.mllib.classification.LogisticRegressionModel;
-import org.apache.spark.mllib.linalg.Vectors;
-import org.apache.spark.mllib.regression.LabeledPoint;
-
-/**
- * Logistic regression based classification using ML Lib.
- */
-public final class JavaLR {
-
- static class ParsePoint implements Function<String, LabeledPoint> {
- private static final Pattern COMMA = Pattern.compile(",");
- private static final Pattern SPACE = Pattern.compile(" ");
-
- @Override
- public LabeledPoint call(String line) {
- String[] parts = COMMA.split(line);
- double y = Double.parseDouble(parts[0]);
- String[] tok = SPACE.split(parts[1]);
- double[] x = new double[tok.length];
- for (int i = 0; i < tok.length; ++i) {
- x[i] = Double.parseDouble(tok[i]);
- }
- return new LabeledPoint(y, Vectors.dense(x));
- }
- }
-
- public static void main(String[] args) {
- if (args.length != 3) {
- System.err.println("Usage: JavaLR <input_dir> <step_size> <niters>");
- System.exit(1);
- }
- SparkConf sparkConf = new SparkConf().setAppName("JavaLR");
- JavaSparkContext sc = new JavaSparkContext(sparkConf);
- JavaRDD<String> lines = sc.textFile(args[0]);
- JavaRDD<LabeledPoint> points = lines.map(new ParsePoint()).cache();
- double stepSize = Double.parseDouble(args[1]);
- int iterations = Integer.parseInt(args[2]);
-
- // Another way to configure LogisticRegression
- //
- // LogisticRegressionWithSGD lr = new LogisticRegressionWithSGD();
- // lr.optimizer().setNumIterations(iterations)
- // .setStepSize(stepSize)
- // .setMiniBatchFraction(1.0);
- // lr.setIntercept(true);
- // LogisticRegressionModel model = lr.train(points.rdd());
-
- LogisticRegressionModel model = LogisticRegressionWithSGD.train(points.rdd(),
- iterations, stepSize);
-
- System.out.print("Final w: " + model.weights());
-
- sc.stop();
- }
-}
diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaMultiLabelClassificationMetricsExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaMultiLabelClassificationMetricsExample.java
index 5904260e2d..bc99dc023f 100644
--- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaMultiLabelClassificationMetricsExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaMultiLabelClassificationMetricsExample.java
@@ -34,13 +34,13 @@ public class JavaMultiLabelClassificationMetricsExample {
JavaSparkContext sc = new JavaSparkContext(conf);
// $example on$
List<Tuple2<double[], double[]>> data = Arrays.asList(
- new Tuple2<double[], double[]>(new double[]{0.0, 1.0}, new double[]{0.0, 2.0}),
- new Tuple2<double[], double[]>(new double[]{0.0, 2.0}, new double[]{0.0, 1.0}),
- new Tuple2<double[], double[]>(new double[]{}, new double[]{0.0}),
- new Tuple2<double[], double[]>(new double[]{2.0}, new double[]{2.0}),
- new Tuple2<double[], double[]>(new double[]{2.0, 0.0}, new double[]{2.0, 0.0}),
- new Tuple2<double[], double[]>(new double[]{0.0, 1.0, 2.0}, new double[]{0.0, 1.0}),
- new Tuple2<double[], double[]>(new double[]{1.0}, new double[]{1.0, 2.0})
+ new Tuple2<>(new double[]{0.0, 1.0}, new double[]{0.0, 2.0}),
+ new Tuple2<>(new double[]{0.0, 2.0}, new double[]{0.0, 1.0}),
+ new Tuple2<>(new double[]{}, new double[]{0.0}),
+ new Tuple2<>(new double[]{2.0}, new double[]{2.0}),
+ new Tuple2<>(new double[]{2.0, 0.0}, new double[]{2.0, 0.0}),
+ new Tuple2<>(new double[]{0.0, 1.0, 2.0}, new double[]{0.0, 1.0}),
+ new Tuple2<>(new double[]{1.0}, new double[]{1.0, 2.0})
);
JavaRDD<Tuple2<double[], double[]>> scoreAndLabels = sc.parallelize(data);
diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaPowerIterationClusteringExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaPowerIterationClusteringExample.java
index b62fa90c34..91c3bd72da 100644
--- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaPowerIterationClusteringExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaPowerIterationClusteringExample.java
@@ -40,11 +40,11 @@ public class JavaPowerIterationClusteringExample {
@SuppressWarnings("unchecked")
// $example on$
JavaRDD<Tuple3<Long, Long, Double>> similarities = sc.parallelize(Lists.newArrayList(
- new Tuple3<Long, Long, Double>(0L, 1L, 0.9),
- new Tuple3<Long, Long, Double>(1L, 2L, 0.9),
- new Tuple3<Long, Long, Double>(2L, 3L, 0.9),
- new Tuple3<Long, Long, Double>(3L, 4L, 0.1),
- new Tuple3<Long, Long, Double>(4L, 5L, 0.9)));
+ new Tuple3<>(0L, 1L, 0.9),
+ new Tuple3<>(1L, 2L, 0.9),
+ new Tuple3<>(2L, 3L, 0.9),
+ new Tuple3<>(3L, 4L, 0.1),
+ new Tuple3<>(4L, 5L, 0.9)));
PowerIterationClustering pic = new PowerIterationClustering()
.setK(2)
diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaStratifiedSamplingExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaStratifiedSamplingExample.java
index c27fba2783..72bbb2a8fa 100644
--- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaStratifiedSamplingExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaStratifiedSamplingExample.java
@@ -35,8 +35,9 @@ public class JavaStratifiedSamplingExample {
SparkConf conf = new SparkConf().setAppName("JavaStratifiedSamplingExample");
JavaSparkContext jsc = new JavaSparkContext(conf);
+ @SuppressWarnings("unchecked")
// $example on$
- List<Tuple2<Integer, Character>> list = new ArrayList<Tuple2<Integer, Character>>(
+ List<Tuple2<Integer, Character>> list = new ArrayList<>(
Arrays.<Tuple2<Integer, Character>>asList(
new Tuple2(1, 'a'),
new Tuple2(1, 'b'),
diff --git a/examples/src/main/java/org/apache/spark/examples/streaming/JavaFlumeEventCount.java b/examples/src/main/java/org/apache/spark/examples/streaming/JavaFlumeEventCount.java
index da56637fe8..bae4b78ac2 100644
--- a/examples/src/main/java/org/apache/spark/examples/streaming/JavaFlumeEventCount.java
+++ b/examples/src/main/java/org/apache/spark/examples/streaming/JavaFlumeEventCount.java
@@ -19,7 +19,6 @@ package org.apache.spark.examples.streaming;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.function.Function;
-import org.apache.spark.examples.streaming.StreamingExamples;
import org.apache.spark.streaming.*;
import org.apache.spark.streaming.api.java.*;
import org.apache.spark.streaming.flume.FlumeUtils;
@@ -58,7 +57,8 @@ public final class JavaFlumeEventCount {
Duration batchInterval = new Duration(2000);
SparkConf sparkConf = new SparkConf().setAppName("JavaFlumeEventCount");
JavaStreamingContext ssc = new JavaStreamingContext(sparkConf, batchInterval);
- JavaReceiverInputDStream<SparkFlumeEvent> flumeStream = FlumeUtils.createStream(ssc, host, port);
+ JavaReceiverInputDStream<SparkFlumeEvent> flumeStream =
+ FlumeUtils.createStream(ssc, host, port);
flumeStream.count();
diff --git a/examples/src/main/python/ml/count_vectorizer_example.py b/examples/src/main/python/ml/count_vectorizer_example.py
new file mode 100644
index 0000000000..e839f645f7
--- /dev/null
+++ b/examples/src/main/python/ml/count_vectorizer_example.py
@@ -0,0 +1,44 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from __future__ import print_function
+
+from pyspark import SparkContext
+from pyspark.sql import SQLContext
+# $example on$
+from pyspark.ml.feature import CountVectorizer
+# $example off$
+
+if __name__ == "__main__":
+ sc = SparkContext(appName="CountVectorizerExample")
+ sqlContext = SQLContext(sc)
+
+ # $example on$
+ # Input data: Each row is a bag of words with an ID.
+ df = sqlContext.createDataFrame([
+ (0, "a b c".split(" ")),
+ (1, "a b b c a".split(" "))
+ ], ["id", "words"])
+
+ # fit a CountVectorizerModel from the corpus.
+ cv = CountVectorizer(inputCol="words", outputCol="features", vocabSize=3, minDF=2.0)
+ model = cv.fit(df)
+ result = model.transform(df)
+ result.show()
+ # $example off$
+
+ sc.stop()
diff --git a/examples/src/main/python/ml/dct_example.py b/examples/src/main/python/ml/dct_example.py
new file mode 100644
index 0000000000..264d47f404
--- /dev/null
+++ b/examples/src/main/python/ml/dct_example.py
@@ -0,0 +1,45 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from __future__ import print_function
+
+from pyspark import SparkContext
+from pyspark.sql import SQLContext
+# $example on$
+from pyspark.ml.feature import DCT
+from pyspark.mllib.linalg import Vectors
+# $example off$
+
+if __name__ == "__main__":
+ sc = SparkContext(appName="DCTExample")
+ sqlContext = SQLContext(sc)
+
+ # $example on$
+ df = sqlContext.createDataFrame([
+ (Vectors.dense([0.0, 1.0, -2.0, 3.0]),),
+ (Vectors.dense([-1.0, 2.0, 4.0, -7.0]),),
+ (Vectors.dense([14.0, -2.0, -5.0, 1.0]),)], ["features"])
+
+ dct = DCT(inverse=False, inputCol="features", outputCol="featuresDCT")
+
+ dctDf = dct.transform(df)
+
+ for dcts in dctDf.select("featuresDCT").take(3):
+ print(dcts)
+ # $example off$
+
+ sc.stop()
diff --git a/examples/src/main/python/ml/max_abs_scaler_example.py b/examples/src/main/python/ml/max_abs_scaler_example.py
new file mode 100644
index 0000000000..d9b69eef1c
--- /dev/null
+++ b/examples/src/main/python/ml/max_abs_scaler_example.py
@@ -0,0 +1,43 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from __future__ import print_function
+
+from pyspark import SparkContext
+from pyspark.sql import SQLContext
+# $example on$
+from pyspark.ml.feature import MaxAbsScaler
+# $example off$
+
+if __name__ == "__main__":
+ sc = SparkContext(appName="MaxAbsScalerExample")
+ sqlContext = SQLContext(sc)
+
+ # $example on$
+ dataFrame = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt")
+
+ scaler = MaxAbsScaler(inputCol="features", outputCol="scaledFeatures")
+
+ # Compute summary statistics and generate MaxAbsScalerModel
+ scalerModel = scaler.fit(dataFrame)
+
+ # rescale each feature to range [-1, 1].
+ scaledData = scalerModel.transform(dataFrame)
+ scaledData.show()
+ # $example off$
+
+ sc.stop()
diff --git a/examples/src/main/python/ml/min_max_scaler_example.py b/examples/src/main/python/ml/min_max_scaler_example.py
new file mode 100644
index 0000000000..2f8e4ade46
--- /dev/null
+++ b/examples/src/main/python/ml/min_max_scaler_example.py
@@ -0,0 +1,43 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from __future__ import print_function
+
+from pyspark import SparkContext
+from pyspark.sql import SQLContext
+# $example on$
+from pyspark.ml.feature import MinMaxScaler
+# $example off$
+
+if __name__ == "__main__":
+ sc = SparkContext(appName="MinMaxScalerExample")
+ sqlContext = SQLContext(sc)
+
+ # $example on$
+ dataFrame = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt")
+
+ scaler = MinMaxScaler(inputCol="features", outputCol="scaledFeatures")
+
+ # Compute summary statistics and generate MinMaxScalerModel
+ scalerModel = scaler.fit(dataFrame)
+
+ # rescale each feature to range [min, max].
+ scaledData = scalerModel.transform(dataFrame)
+ scaledData.show()
+ # $example off$
+
+ sc.stop()
diff --git a/examples/src/main/python/ml/naive_bayes_example.py b/examples/src/main/python/ml/naive_bayes_example.py
new file mode 100644
index 0000000000..db8fbea9bf
--- /dev/null
+++ b/examples/src/main/python/ml/naive_bayes_example.py
@@ -0,0 +1,53 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from __future__ import print_function
+
+from pyspark import SparkContext
+from pyspark.sql import SQLContext
+# $example on$
+from pyspark.ml.classification import NaiveBayes
+from pyspark.ml.evaluation import MulticlassClassificationEvaluator
+# $example off$
+
+if __name__ == "__main__":
+
+ sc = SparkContext(appName="naive_bayes_example")
+ sqlContext = SQLContext(sc)
+
+ # $example on$
+ # Load training data
+ data = sqlContext.read.format("libsvm") \
+ .load("data/mllib/sample_libsvm_data.txt")
+ # Split the data into train and test
+ splits = data.randomSplit([0.6, 0.4], 1234)
+ train = splits[0]
+ test = splits[1]
+
+ # create the trainer and set its parameters
+ nb = NaiveBayes(smoothing=1.0, modelType="multinomial")
+
+ # train the model
+ model = nb.fit(train)
+ # compute precision on the test set
+ result = model.transform(test)
+ predictionAndLabels = result.select("prediction", "label")
+ evaluator = MulticlassClassificationEvaluator(metricName="precision")
+ print("Precision: " + str(evaluator.evaluate(predictionAndLabels)))
+ # $example off$
+
+ sc.stop()
diff --git a/examples/src/main/scala/org/apache/spark/examples/BroadcastTest.scala b/examples/src/main/scala/org/apache/spark/examples/BroadcastTest.scala
index 3da5236745..af5a815f6e 100644
--- a/examples/src/main/scala/org/apache/spark/examples/BroadcastTest.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/BroadcastTest.scala
@@ -21,8 +21,8 @@ package org.apache.spark.examples
import org.apache.spark.{SparkConf, SparkContext}
/**
- * Usage: BroadcastTest [slices] [numElem] [blockSize]
- */
+ * Usage: BroadcastTest [slices] [numElem] [blockSize]
+ */
object BroadcastTest {
def main(args: Array[String]) {
diff --git a/examples/src/main/scala/org/apache/spark/examples/CassandraCQLTest.scala b/examples/src/main/scala/org/apache/spark/examples/CassandraCQLTest.scala
index 973b005f91..ca4eea2356 100644
--- a/examples/src/main/scala/org/apache/spark/examples/CassandraCQLTest.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/CassandraCQLTest.scala
@@ -106,9 +106,8 @@ object CassandraCQLTest {
println("Count: " + casRdd.count)
val productSaleRDD = casRdd.map {
- case (key, value) => {
+ case (key, value) =>
(ByteBufferUtil.string(value.get("prod_id")), ByteBufferUtil.toInt(value.get("quantity")))
- }
}
val aggregatedRDD = productSaleRDD.reduceByKey(_ + _)
aggregatedRDD.collect().foreach {
@@ -116,11 +115,10 @@ object CassandraCQLTest {
}
val casoutputCF = aggregatedRDD.map {
- case (productId, saleCount) => {
+ case (productId, saleCount) =>
val outKey = Collections.singletonMap("prod_id", ByteBufferUtil.bytes(productId))
val outVal = Collections.singletonList(ByteBufferUtil.bytes(saleCount))
(outKey, outVal)
- }
}
casoutputCF.saveAsNewAPIHadoopFile(
diff --git a/examples/src/main/scala/org/apache/spark/examples/CassandraTest.scala b/examples/src/main/scala/org/apache/spark/examples/CassandraTest.scala
index 6a8f73ad00..eff840d36e 100644
--- a/examples/src/main/scala/org/apache/spark/examples/CassandraTest.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/CassandraTest.scala
@@ -90,9 +90,8 @@ object CassandraTest {
// Let us first get all the paragraphs from the retrieved rows
val paraRdd = casRdd.map {
- case (key, value) => {
+ case (key, value) =>
ByteBufferUtil.string(value.get(ByteBufferUtil.bytes("para")).value())
- }
}
// Lets get the word count in paras
@@ -103,7 +102,7 @@ object CassandraTest {
}
counts.map {
- case (word, count) => {
+ case (word, count) =>
val colWord = new org.apache.cassandra.thrift.Column()
colWord.setName(ByteBufferUtil.bytes("word"))
colWord.setValue(ByteBufferUtil.bytes(word))
@@ -122,7 +121,6 @@ object CassandraTest {
mutations.get(1).setColumn_or_supercolumn(new ColumnOrSuperColumn())
mutations.get(1).column_or_supercolumn.setColumn(colCount)
(outputkey, mutations)
- }
}.saveAsNewAPIHadoopFile("casDemo", classOf[ByteBuffer], classOf[List[Mutation]],
classOf[ColumnFamilyOutputFormat], job.getConfiguration)
diff --git a/examples/src/main/scala/org/apache/spark/examples/DFSReadWriteTest.scala b/examples/src/main/scala/org/apache/spark/examples/DFSReadWriteTest.scala
index 743fc13db7..7bf023667d 100644
--- a/examples/src/main/scala/org/apache/spark/examples/DFSReadWriteTest.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/DFSReadWriteTest.scala
@@ -25,16 +25,16 @@ import scala.io.Source._
import org.apache.spark.{SparkConf, SparkContext}
/**
- * Simple test for reading and writing to a distributed
- * file system. This example does the following:
- *
- * 1. Reads local file
- * 2. Computes word count on local file
- * 3. Writes local file to a DFS
- * 4. Reads the file back from the DFS
- * 5. Computes word count on the file using Spark
- * 6. Compares the word count results
- */
+ * Simple test for reading and writing to a distributed
+ * file system. This example does the following:
+ *
+ * 1. Reads local file
+ * 2. Computes word count on local file
+ * 3. Writes local file to a DFS
+ * 4. Reads the file back from the DFS
+ * 5. Computes word count on the file using Spark
+ * 6. Compares the word count results
+ */
object DFSReadWriteTest {
private var localFilePath: File = new File(".")
diff --git a/examples/src/main/scala/org/apache/spark/examples/DriverSubmissionTest.scala b/examples/src/main/scala/org/apache/spark/examples/DriverSubmissionTest.scala
index a2d59a1c95..d12ef642bd 100644
--- a/examples/src/main/scala/org/apache/spark/examples/DriverSubmissionTest.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/DriverSubmissionTest.scala
@@ -22,8 +22,10 @@ import scala.collection.JavaConverters._
import org.apache.spark.util.Utils
-/** Prints out environmental information, sleeps, and then exits. Made to
- * test driver submission in the standalone scheduler. */
+/**
+ * Prints out environmental information, sleeps, and then exits. Made to
+ * test driver submission in the standalone scheduler.
+ */
object DriverSubmissionTest {
def main(args: Array[String]) {
if (args.length < 1) {
diff --git a/examples/src/main/scala/org/apache/spark/examples/GroupByTest.scala b/examples/src/main/scala/org/apache/spark/examples/GroupByTest.scala
index 08b6c717d4..4db229b5de 100644
--- a/examples/src/main/scala/org/apache/spark/examples/GroupByTest.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/GroupByTest.scala
@@ -23,8 +23,8 @@ import java.util.Random
import org.apache.spark.{SparkConf, SparkContext}
/**
- * Usage: GroupByTest [numMappers] [numKVPairs] [KeySize] [numReducers]
- */
+ * Usage: GroupByTest [numMappers] [numKVPairs] [KeySize] [numReducers]
+ */
object GroupByTest {
def main(args: Array[String]) {
val sparkConf = new SparkConf().setAppName("GroupBy Test")
diff --git a/examples/src/main/scala/org/apache/spark/examples/LocalALS.scala b/examples/src/main/scala/org/apache/spark/examples/LocalALS.scala
index af5f216f28..fa10101955 100644
--- a/examples/src/main/scala/org/apache/spark/examples/LocalALS.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/LocalALS.scala
@@ -104,16 +104,14 @@ object LocalALS {
def main(args: Array[String]) {
args match {
- case Array(m, u, f, iters) => {
+ case Array(m, u, f, iters) =>
M = m.toInt
U = u.toInt
F = f.toInt
ITERATIONS = iters.toInt
- }
- case _ => {
+ case _ =>
System.err.println("Usage: LocalALS <M> <U> <F> <iters>")
System.exit(1)
- }
}
showWarning()
diff --git a/examples/src/main/scala/org/apache/spark/examples/MultiBroadcastTest.scala b/examples/src/main/scala/org/apache/spark/examples/MultiBroadcastTest.scala
index 134c3d1d63..3eb0c27723 100644
--- a/examples/src/main/scala/org/apache/spark/examples/MultiBroadcastTest.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/MultiBroadcastTest.scala
@@ -22,8 +22,8 @@ import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.rdd.RDD
/**
- * Usage: MultiBroadcastTest [slices] [numElem]
- */
+ * Usage: MultiBroadcastTest [slices] [numElem]
+ */
object MultiBroadcastTest {
def main(args: Array[String]) {
diff --git a/examples/src/main/scala/org/apache/spark/examples/SimpleSkewedGroupByTest.scala b/examples/src/main/scala/org/apache/spark/examples/SimpleSkewedGroupByTest.scala
index 7c09664c2f..ec07e6323e 100644
--- a/examples/src/main/scala/org/apache/spark/examples/SimpleSkewedGroupByTest.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/SimpleSkewedGroupByTest.scala
@@ -23,8 +23,8 @@ import java.util.Random
import org.apache.spark.{SparkConf, SparkContext}
/**
- * Usage: SimpleSkewedGroupByTest [numMappers] [numKVPairs] [valSize] [numReducers] [ratio]
- */
+ * Usage: SimpleSkewedGroupByTest [numMappers] [numKVPairs] [valSize] [numReducers] [ratio]
+ */
object SimpleSkewedGroupByTest {
def main(args: Array[String]) {
diff --git a/examples/src/main/scala/org/apache/spark/examples/SkewedGroupByTest.scala b/examples/src/main/scala/org/apache/spark/examples/SkewedGroupByTest.scala
index d498af9c39..8e4c2b6229 100644
--- a/examples/src/main/scala/org/apache/spark/examples/SkewedGroupByTest.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/SkewedGroupByTest.scala
@@ -23,8 +23,8 @@ import java.util.Random
import org.apache.spark.{SparkConf, SparkContext}
/**
- * Usage: GroupByTest [numMappers] [numKVPairs] [KeySize] [numReducers]
- */
+ * Usage: GroupByTest [numMappers] [numKVPairs] [KeySize] [numReducers]
+ */
object SkewedGroupByTest {
def main(args: Array[String]) {
val sparkConf = new SparkConf().setAppName("GroupBy Test")
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala
index c1f63c6a1d..8d127f9b35 100644
--- a/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala
@@ -24,7 +24,7 @@ import org.apache.spark.ml.param.{IntParam, ParamMap}
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.mllib.linalg.{BLAS, Vector, Vectors}
import org.apache.spark.mllib.regression.LabeledPoint
-import org.apache.spark.sql.{DataFrame, Row, SQLContext}
+import org.apache.spark.sql.{DataFrame, Dataset, Row, SQLContext}
/**
* A simple example demonstrating how to write your own learning algorithm using Estimator,
@@ -120,7 +120,7 @@ private class MyLogisticRegression(override val uid: String)
def setMaxIter(value: Int): this.type = set(maxIter, value)
// This method is used by fit()
- override protected def train(dataset: DataFrame): MyLogisticRegressionModel = {
+ override protected def train(dataset: Dataset[_]): MyLogisticRegressionModel = {
// Extract columns from data using helper method.
val oldDataset = extractLabeledPoints(dataset)
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/KMeansExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/KMeansExample.scala
index af90652b55..7af011571f 100644
--- a/examples/src/main/scala/org/apache/spark/examples/ml/KMeansExample.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/KMeansExample.scala
@@ -23,8 +23,8 @@ import org.apache.spark.{SparkConf, SparkContext}
// $example on$
import org.apache.spark.ml.clustering.KMeans
import org.apache.spark.mllib.linalg.Vectors
-// $example off$
import org.apache.spark.sql.{DataFrame, SQLContext}
+// $example off$
/**
* An example demonstrating a k-means clustering.
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/NaiveBayesExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/NaiveBayesExample.scala
new file mode 100644
index 0000000000..5ea1270c97
--- /dev/null
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/NaiveBayesExample.scala
@@ -0,0 +1,58 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// scalastyle:off println
+package org.apache.spark.examples.ml
+
+import org.apache.spark.{SparkConf, SparkContext}
+// $example on$
+import org.apache.spark.ml.classification.NaiveBayes
+import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
+// $example off$
+import org.apache.spark.sql.SQLContext
+
+object NaiveBayesExample {
+ def main(args: Array[String]): Unit = {
+ val conf = new SparkConf().setAppName("NaiveBayesExample")
+ val sc = new SparkContext(conf)
+ val sqlContext = new SQLContext(sc)
+ // $example on$
+ // Load the data stored in LIBSVM format as a DataFrame.
+ val data = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt")
+
+ // Split the data into training and test sets (30% held out for testing)
+ val Array(trainingData, testData) = data.randomSplit(Array(0.7, 0.3))
+
+ // Train a NaiveBayes model.
+ val model = new NaiveBayes()
+ .fit(trainingData)
+
+ // Select example rows to display.
+ val predictions = model.transform(testData)
+ predictions.show()
+
+ // Select (prediction, true label) and compute test error
+ val evaluator = new MulticlassClassificationEvaluator()
+ .setLabelCol("label")
+ .setPredictionCol("prediction")
+ .setMetricName("precision")
+ val precision = evaluator.evaluate(predictions)
+ println("Precision: " + precision)
+ // $example off$
+ }
+}
+// scalastyle:on println
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/OneVsRestExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/OneVsRestExample.scala
index a0bb5dabf4..0b5d31c0ff 100644
--- a/examples/src/main/scala/org/apache/spark/examples/ml/OneVsRestExample.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/OneVsRestExample.scala
@@ -118,17 +118,15 @@ object OneVsRestExample {
val inputData = sqlContext.read.format("libsvm").load(params.input)
// compute the train/test split: if testInput is not provided use part of input.
val data = params.testInput match {
- case Some(t) => {
+ case Some(t) =>
// compute the number of features in the training set.
val numFeatures = inputData.first().getAs[Vector](1).size
val testData = sqlContext.read.option("numFeatures", numFeatures.toString)
.format("libsvm").load(t)
Array[DataFrame](inputData, testData)
- }
- case None => {
+ case None =>
val f = params.fracTest
inputData.randomSplit(Array(1 - f, f), seed = 12345)
- }
}
val Array(train, test) = data.map(_.cache())
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala
index c263f4f595..ee811d3aa1 100644
--- a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala
@@ -180,7 +180,7 @@ object DecisionTreeRunner {
}
// For classification, re-index classes if needed.
val (examples, classIndexMap, numClasses) = algo match {
- case Classification => {
+ case Classification =>
// classCounts: class --> # examples in class
val classCounts = origExamples.map(_.label).countByValue()
val sortedClasses = classCounts.keys.toList.sorted
@@ -209,7 +209,6 @@ object DecisionTreeRunner {
println(s"$c\t$frac\t${classCounts(c)}")
}
(examples, classIndexMap, numClasses)
- }
case Regression =>
(origExamples, null, 0)
case _ =>
@@ -225,7 +224,7 @@ object DecisionTreeRunner {
case "libsvm" => MLUtils.loadLibSVMFile(sc, testInput, numFeatures)
}
algo match {
- case Classification => {
+ case Classification =>
// classCounts: class --> # examples in class
val testExamples = {
if (classIndexMap.isEmpty) {
@@ -235,7 +234,6 @@ object DecisionTreeRunner {
}
}
Array(examples, testExamples)
- }
case Regression =>
Array(examples, origTestExamples)
}
diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala
index 05f8e65d65..bb2af9cd72 100644
--- a/examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala
@@ -116,7 +116,7 @@ object RecoverableNetworkWordCount {
val lines = ssc.socketTextStream(ip, port)
val words = lines.flatMap(_.split(" "))
val wordCounts = words.map(x => (x, 1)).reduceByKey(_ + _)
- wordCounts.foreachRDD((rdd: RDD[(String, Int)], time: Time) => {
+ wordCounts.foreachRDD { (rdd: RDD[(String, Int)], time: Time) =>
// Get or register the blacklist Broadcast
val blacklist = WordBlacklist.getInstance(rdd.sparkContext)
// Get or register the droppedWordsCounter Accumulator
@@ -135,13 +135,13 @@ object RecoverableNetworkWordCount {
println("Dropped " + droppedWordsCounter.value + " word(s) totally")
println("Appending to " + outputFile.getAbsolutePath)
Files.append(output + "\n", outputFile, Charset.defaultCharset())
- })
+ }
ssc
}
def main(args: Array[String]) {
if (args.length != 4) {
- System.err.println("You arguments were " + args.mkString("[", ", ", "]"))
+ System.err.println("Your arguments were " + args.mkString("[", ", ", "]"))
System.err.println(
"""
|Usage: RecoverableNetworkWordCount <hostname> <port> <checkpoint-directory>
diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/SqlNetworkWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/SqlNetworkWordCount.scala
index 3727f8fe6a..918e124065 100644
--- a/examples/src/main/scala/org/apache/spark/examples/streaming/SqlNetworkWordCount.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/streaming/SqlNetworkWordCount.scala
@@ -59,7 +59,7 @@ object SqlNetworkWordCount {
val words = lines.flatMap(_.split(" "))
// Convert RDDs of the words DStream to DataFrame and run SQL query
- words.foreachRDD((rdd: RDD[String], time: Time) => {
+ words.foreachRDD { (rdd: RDD[String], time: Time) =>
// Get the singleton instance of SQLContext
val sqlContext = SQLContextSingleton.getInstance(rdd.sparkContext)
import sqlContext.implicits._
@@ -75,7 +75,7 @@ object SqlNetworkWordCount {
sqlContext.sql("select word, count(*) as total from words group by word")
println(s"========= $time =========")
wordCountsDataFrame.show()
- })
+ }
ssc.start()
ssc.awaitTermination()
diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewGenerator.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewGenerator.scala
index 50216b9bd4..0ddd065f0d 100644
--- a/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewGenerator.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewGenerator.scala
@@ -38,17 +38,18 @@ object PageView extends Serializable {
}
// scalastyle:off
-/** Generates streaming events to simulate page views on a website.
- *
- * This should be used in tandem with PageViewStream.scala. Example:
- *
- * To run the generator
- * `$ bin/run-example org.apache.spark.examples.streaming.clickstream.PageViewGenerator 44444 10`
- * To process the generated stream
- * `$ bin/run-example \
- * org.apache.spark.examples.streaming.clickstream.PageViewStream errorRatePerZipCode localhost 44444`
- *
- */
+/**
+ * Generates streaming events to simulate page views on a website.
+ *
+ * This should be used in tandem with PageViewStream.scala. Example:
+ *
+ * To run the generator
+ * `$ bin/run-example org.apache.spark.examples.streaming.clickstream.PageViewGenerator 44444 10`
+ * To process the generated stream
+ * `$ bin/run-example \
+ * org.apache.spark.examples.streaming.clickstream.PageViewStream errorRatePerZipCode localhost 44444`
+ *
+ */
// scalastyle:on
object PageViewGenerator {
val pages = Map("http://foo.com/" -> .7,
diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewStream.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewStream.scala
index 773a2e5fc2..1ba093f57b 100644
--- a/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewStream.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewStream.scala
@@ -22,16 +22,17 @@ import org.apache.spark.examples.streaming.StreamingExamples
import org.apache.spark.streaming.{Seconds, StreamingContext}
// scalastyle:off
-/** Analyses a streaming dataset of web page views. This class demonstrates several types of
- * operators available in Spark streaming.
- *
- * This should be used in tandem with PageViewStream.scala. Example:
- * To run the generator
- * `$ bin/run-example org.apache.spark.examples.streaming.clickstream.PageViewGenerator 44444 10`
- * To process the generated stream
- * `$ bin/run-example \
- * org.apache.spark.examples.streaming.clickstream.PageViewStream errorRatePerZipCode localhost 44444`
- */
+/**
+ * Analyses a streaming dataset of web page views. This class demonstrates several types of
+ * operators available in Spark streaming.
+ *
+ * This should be used in tandem with PageViewStream.scala. Example:
+ * To run the generator
+ * `$ bin/run-example org.apache.spark.examples.streaming.clickstream.PageViewGenerator 44444 10`
+ * To process the generated stream
+ * `$ bin/run-example \
+ * org.apache.spark.examples.streaming.clickstream.PageViewStream errorRatePerZipCode localhost 44444`
+ */
// scalastyle:on
object PageViewStream {
def main(args: Array[String]) {
diff --git a/external/docker-integration-tests/pom.xml b/external/docker-integration-tests/pom.xml
index 1764aa9465..17fd7d781c 100644
--- a/external/docker-integration-tests/pom.xml
+++ b/external/docker-integration-tests/pom.xml
@@ -34,6 +34,13 @@
<sbt.project.name>docker-integration-tests</sbt.project.name>
</properties>
+ <repositories>
+ <repository>
+ <id>db2</id>
+ <url>https://app.camunda.com/nexus/content/repositories/public/</url>
+ </repository>
+ </repositories>
+
<dependencies>
<dependency>
<groupId>com.spotify</groupId>
@@ -180,5 +187,28 @@
</exclusions>
</dependency>
<!-- End Jersey dependencies -->
+
+ <!-- DB2 JCC driver manual installation instructions
+
+ You can build this datasource if you:
+ 1) have the DB2 artifacts installed in a local repo and supply the URL:
+ -Dmaven.repo.drivers=http://my.local.repo
+
+ 2) have a copy of the DB2 JCC driver and run the following commands:
+ mvn install:install-file -Dfile=${path to db2jcc4.jar} \
+ -DgroupId=com.ibm.db2 \
+ -DartifactId=db2jcc4 \
+ -Dversion=10.5 \
+ -Dpackaging=jar
+
+ Note: IBM DB2 JCC driver is available for download at
+ http://www-01.ibm.com/support/docview.wss?uid=swg21363866
+ -->
+ <dependency>
+ <groupId>com.ibm.db2.jcc</groupId>
+ <artifactId>db2jcc4</artifactId>
+ <version>10.5.0.5</version>
+ <type>jar</type>
+ </dependency>
</dependencies>
</project>
diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/DB2IntegrationSuite.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/DB2IntegrationSuite.scala
new file mode 100644
index 0000000000..4fe1ef6697
--- /dev/null
+++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/DB2IntegrationSuite.scala
@@ -0,0 +1,157 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.jdbc
+
+import java.math.BigDecimal
+import java.sql.{Connection, Date, Timestamp}
+import java.util.Properties
+
+import org.scalatest._
+
+import org.apache.spark.tags.DockerTest
+
+@DockerTest
+@Ignore // AMPLab Jenkins needs to be updated before shared memory works on docker
+class DB2IntegrationSuite extends DockerJDBCIntegrationSuite {
+ override val db = new DatabaseOnDocker {
+ override val imageName = "lresende/db2express-c:10.5.0.5-3.10.0"
+ override val env = Map(
+ "DB2INST1_PASSWORD" -> "rootpass",
+ "LICENSE" -> "accept"
+ )
+ override val usesIpc = true
+ override val jdbcPort: Int = 50000
+ override def getJdbcUrl(ip: String, port: Int): String =
+ s"jdbc:db2://$ip:$port/foo:user=db2inst1;password=rootpass;"
+ override def getStartupProcessName: Option[String] = Some("db2start")
+ }
+
+ override def dataPreparation(conn: Connection): Unit = {
+ conn.prepareStatement("CREATE TABLE tbl (x INTEGER, y VARCHAR(8))").executeUpdate()
+ conn.prepareStatement("INSERT INTO tbl VALUES (42,'fred')").executeUpdate()
+ conn.prepareStatement("INSERT INTO tbl VALUES (17,'dave')").executeUpdate()
+
+ conn.prepareStatement("CREATE TABLE numbers (onebit BIT(1), tenbits BIT(10), "
+ + "small SMALLINT, med MEDIUMINT, nor INT, big BIGINT, deci DECIMAL(40,20), flt FLOAT, "
+ + "dbl DOUBLE)").executeUpdate()
+ conn.prepareStatement("INSERT INTO numbers VALUES (b'0', b'1000100101', "
+ + "17, 77777, 123456789, 123456789012345, 123456789012345.123456789012345, "
+ + "42.75, 1.0000000000000002)").executeUpdate()
+
+ conn.prepareStatement("CREATE TABLE dates (d DATE, t TIME, dt DATETIME, ts TIMESTAMP, "
+ + "yr YEAR)").executeUpdate()
+ conn.prepareStatement("INSERT INTO dates VALUES ('1991-11-09', '13:31:24', "
+ + "'1996-01-01 01:23:45', '2009-02-13 23:31:30', '2001')").executeUpdate()
+
+ // TODO: Test locale conversion for strings.
+ conn.prepareStatement("CREATE TABLE strings (a CHAR(10), b VARCHAR(10), c CLOB, d BLOB, "
+ + "e CHAR FOR BIT DATA)").executeUpdate()
+ conn.prepareStatement("INSERT INTO strings VALUES ('the', 'quick', 'brown', 'fox', 'jumps'")
+ .executeUpdate()
+ }
+
+ test("Basic test") {
+ val df = sqlContext.read.jdbc(jdbcUrl, "tbl", new Properties)
+ val rows = df.collect()
+ assert(rows.length == 2)
+ val types = rows(0).toSeq.map(x => x.getClass.toString)
+ assert(types.length == 2)
+ assert(types(0).equals("class java.lang.Integer"))
+ assert(types(1).equals("class java.lang.String"))
+ }
+
+ test("Numeric types") {
+ val df = sqlContext.read.jdbc(jdbcUrl, "numbers", new Properties)
+ val rows = df.collect()
+ assert(rows.length == 1)
+ val types = rows(0).toSeq.map(x => x.getClass.toString)
+ assert(types.length == 9)
+ assert(types(0).equals("class java.lang.Boolean"))
+ assert(types(1).equals("class java.lang.Long"))
+ assert(types(2).equals("class java.lang.Integer"))
+ assert(types(3).equals("class java.lang.Integer"))
+ assert(types(4).equals("class java.lang.Integer"))
+ assert(types(5).equals("class java.lang.Long"))
+ assert(types(6).equals("class java.math.BigDecimal"))
+ assert(types(7).equals("class java.lang.Double"))
+ assert(types(8).equals("class java.lang.Double"))
+ assert(rows(0).getBoolean(0) == false)
+ assert(rows(0).getLong(1) == 0x225)
+ assert(rows(0).getInt(2) == 17)
+ assert(rows(0).getInt(3) == 77777)
+ assert(rows(0).getInt(4) == 123456789)
+ assert(rows(0).getLong(5) == 123456789012345L)
+ val bd = new BigDecimal("123456789012345.12345678901234500000")
+ assert(rows(0).getAs[BigDecimal](6).equals(bd))
+ assert(rows(0).getDouble(7) == 42.75)
+ assert(rows(0).getDouble(8) == 1.0000000000000002)
+ }
+
+ test("Date types") {
+ val df = sqlContext.read.jdbc(jdbcUrl, "dates", new Properties)
+ val rows = df.collect()
+ assert(rows.length == 1)
+ val types = rows(0).toSeq.map(x => x.getClass.toString)
+ assert(types.length == 5)
+ assert(types(0).equals("class java.sql.Date"))
+ assert(types(1).equals("class java.sql.Timestamp"))
+ assert(types(2).equals("class java.sql.Timestamp"))
+ assert(types(3).equals("class java.sql.Timestamp"))
+ assert(types(4).equals("class java.sql.Date"))
+ assert(rows(0).getAs[Date](0).equals(Date.valueOf("1991-11-09")))
+ assert(rows(0).getAs[Timestamp](1).equals(Timestamp.valueOf("1970-01-01 13:31:24")))
+ assert(rows(0).getAs[Timestamp](2).equals(Timestamp.valueOf("1996-01-01 01:23:45")))
+ assert(rows(0).getAs[Timestamp](3).equals(Timestamp.valueOf("2009-02-13 23:31:30")))
+ assert(rows(0).getAs[Date](4).equals(Date.valueOf("2001-01-01")))
+ }
+
+ test("String types") {
+ val df = sqlContext.read.jdbc(jdbcUrl, "strings", new Properties)
+ val rows = df.collect()
+ assert(rows.length == 1)
+ val types = rows(0).toSeq.map(x => x.getClass.toString)
+ assert(types.length == 9)
+ assert(types(0).equals("class java.lang.String"))
+ assert(types(1).equals("class java.lang.String"))
+ assert(types(2).equals("class java.lang.String"))
+ assert(types(3).equals("class java.lang.String"))
+ assert(types(4).equals("class java.lang.String"))
+ assert(types(5).equals("class java.lang.String"))
+ assert(types(6).equals("class [B"))
+ assert(types(7).equals("class [B"))
+ assert(types(8).equals("class [B"))
+ assert(rows(0).getString(0).equals("the"))
+ assert(rows(0).getString(1).equals("quick"))
+ assert(rows(0).getString(2).equals("brown"))
+ assert(rows(0).getString(3).equals("fox"))
+ assert(rows(0).getString(4).equals("jumps"))
+ assert(rows(0).getString(5).equals("over"))
+ assert(java.util.Arrays.equals(rows(0).getAs[Array[Byte]](6), Array[Byte](116, 104, 101, 0)))
+ assert(java.util.Arrays.equals(rows(0).getAs[Array[Byte]](7), Array[Byte](108, 97, 122, 121)))
+ assert(java.util.Arrays.equals(rows(0).getAs[Array[Byte]](8), Array[Byte](100, 111, 103)))
+ }
+
+ test("Basic write test") {
+ val df1 = sqlContext.read.jdbc(jdbcUrl, "numbers", new Properties)
+ val df2 = sqlContext.read.jdbc(jdbcUrl, "dates", new Properties)
+ val df3 = sqlContext.read.jdbc(jdbcUrl, "strings", new Properties)
+ df1.write.jdbc(jdbcUrl, "numberscopy", new Properties)
+ df2.write.jdbc(jdbcUrl, "datescopy", new Properties)
+ df3.write.jdbc(jdbcUrl, "stringscopy", new Properties)
+ }
+}
diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/DockerJDBCIntegrationSuite.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/DockerJDBCIntegrationSuite.scala
index f73231fc80..c36f4d5f95 100644
--- a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/DockerJDBCIntegrationSuite.scala
+++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/DockerJDBCIntegrationSuite.scala
@@ -45,6 +45,11 @@ abstract class DatabaseOnDocker {
val env: Map[String, String]
/**
+   * Whether or not to use IPC mode for shared memory when starting the Docker image.
+ */
+ val usesIpc: Boolean
+
+ /**
* The container-internal JDBC port that the database listens on.
*/
val jdbcPort: Int
@@ -53,6 +58,11 @@ abstract class DatabaseOnDocker {
* Return a JDBC URL that connects to the database running at the given IP address and port.
*/
def getJdbcUrl(ip: String, port: Int): String
+
+ /**
+   * Optional process to run when the container starts.
+ */
+ def getStartupProcessName: Option[String]
}
abstract class DockerJDBCIntegrationSuite
@@ -97,17 +107,23 @@ abstract class DockerJDBCIntegrationSuite
val dockerIp = DockerUtils.getDockerIp()
val hostConfig: HostConfig = HostConfig.builder()
.networkMode("bridge")
+ .ipcMode(if (db.usesIpc) "host" else "")
.portBindings(
Map(s"${db.jdbcPort}/tcp" -> List(PortBinding.of(dockerIp, externalPort)).asJava).asJava)
.build()
// Create the database container:
- val config = ContainerConfig.builder()
+ val containerConfigBuilder = ContainerConfig.builder()
.image(db.imageName)
.networkDisabled(false)
.env(db.env.map { case (k, v) => s"$k=$v" }.toSeq.asJava)
.hostConfig(hostConfig)
.exposedPorts(s"${db.jdbcPort}/tcp")
- .build()
+    if (db.getStartupProcessName.isDefined) {
+ containerConfigBuilder
+ .cmd(db.getStartupProcessName.get)
+ }
+ val config = containerConfigBuilder.build()
+ // Create the database container:
containerId = docker.createContainer(config).id
// Start the container and wait until the database can accept JDBC connections:
docker.startContainer(containerId)
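For readers wiring a new suite against the two members this hunk adds to `DatabaseOnDocker`, here is a minimal sketch in the style of the MySQL/Oracle/Postgres overrides further down; the image name, env variable, port, and JDBC URL are placeholders, not part of this patch:

```scala
// Illustrative only: a hypothetical DatabaseOnDocker definition exercising the two
// new members (usesIpc, getStartupProcessName). All literals below are placeholders.
override val db = new DatabaseOnDocker {
  override val imageName = "vendor/somedb:1.0"
  override val env = Map("SOMEDB_PASSWORD" -> "rootpass")
  override val usesIpc = false                        // no host IPC namespace needed
  override val jdbcPort: Int = 12345                  // container-internal JDBC port
  override def getJdbcUrl(ip: String, port: Int): String =
    s"jdbc:somedb://$ip:$port/test"
  override def getStartupProcessName: Option[String] = None  // nothing extra to launch
}
```

When `usesIpc` is true the container shares the host IPC namespace (as the DB2 image requires), and a defined startup process name is passed to the container as its command.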
diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MySQLIntegrationSuite.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MySQLIntegrationSuite.scala
index c68e4dc493..a70ed98b52 100644
--- a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MySQLIntegrationSuite.scala
+++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MySQLIntegrationSuite.scala
@@ -30,9 +30,11 @@ class MySQLIntegrationSuite extends DockerJDBCIntegrationSuite {
override val env = Map(
"MYSQL_ROOT_PASSWORD" -> "rootpass"
)
+ override val usesIpc = false
override val jdbcPort: Int = 3306
override def getJdbcUrl(ip: String, port: Int): String =
s"jdbc:mysql://$ip:$port/mysql?user=root&password=rootpass"
+ override def getStartupProcessName: Option[String] = None
}
override def dataPreparation(conn: Connection): Unit = {
diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/OracleIntegrationSuite.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/OracleIntegrationSuite.scala
index 8a0f938f7e..2fc174eb1b 100644
--- a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/OracleIntegrationSuite.scala
+++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/OracleIntegrationSuite.scala
@@ -52,9 +52,11 @@ class OracleIntegrationSuite extends DockerJDBCIntegrationSuite with SharedSQLCo
override val env = Map(
"ORACLE_ROOT_PASSWORD" -> "oracle"
)
+ override val usesIpc = false
override val jdbcPort: Int = 1521
override def getJdbcUrl(ip: String, port: Int): String =
s"jdbc:oracle:thin:system/oracle@//$ip:$port/xe"
+ override def getStartupProcessName: Option[String] = None
}
override def dataPreparation(conn: Connection): Unit = {
diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala
index d55cdcf28b..79dd70116e 100644
--- a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala
+++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala
@@ -32,9 +32,11 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationSuite {
override val env = Map(
"POSTGRES_PASSWORD" -> "rootpass"
)
+ override val usesIpc = false
override val jdbcPort = 5432
override def getJdbcUrl(ip: String, port: Int): String =
s"jdbc:postgresql://$ip:$port/postgres?user=postgres&password=rootpass"
+ override def getStartupProcessName: Option[String] = None
}
override def dataPreparation(conn: Connection): Unit = {
diff --git a/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkAvroCallbackHandler.scala b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkAvroCallbackHandler.scala
index 719fca0938..8050ec357e 100644
--- a/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkAvroCallbackHandler.scala
+++ b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkAvroCallbackHandler.scala
@@ -129,9 +129,9 @@ private[flume] class SparkAvroCallbackHandler(val threads: Int, val channel: Cha
* @param success Whether the batch was successful or not.
*/
private def completeTransaction(sequenceNumber: CharSequence, success: Boolean) {
- removeAndGetProcessor(sequenceNumber).foreach(processor => {
+ removeAndGetProcessor(sequenceNumber).foreach { processor =>
processor.batchProcessed(success)
- })
+ }
}
/**
diff --git a/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkSink.scala b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkSink.scala
index 14dffb15fe..41f27e9376 100644
--- a/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkSink.scala
+++ b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkSink.scala
@@ -88,23 +88,23 @@ class SparkSink extends AbstractSink with Logging with Configurable {
// dependencies which are being excluded in the build. In practice,
// Netty dependencies are already available on the JVM as Flume would have pulled them in.
serverOpt = Option(new NettyServer(responder, new InetSocketAddress(hostname, port)))
- serverOpt.foreach(server => {
+ serverOpt.foreach { server =>
logInfo("Starting Avro server for sink: " + getName)
server.start()
- })
+ }
super.start()
}
override def stop() {
logInfo("Stopping Spark Sink: " + getName)
- handler.foreach(callbackHandler => {
+ handler.foreach { callbackHandler =>
callbackHandler.shutdown()
- })
- serverOpt.foreach(server => {
+ }
+ serverOpt.foreach { server =>
logInfo("Stopping Avro Server for sink: " + getName)
server.close()
server.join()
- })
+ }
blockingLatch.countDown()
super.stop()
}
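The hunks in this file (and in the Flume and Kinesis files that follow) only swap `opt.foreach(x => { ... })` for the brace form of a single-argument call; a small standalone sketch with illustrative values:

```scala
// Illustrative only: the two forms are equivalent; the patch standardizes on the second.
val serverOpt: Option[String] = Some("avro-server")

serverOpt.foreach(server => {   // before: parentheses around an anonymous-function block
  println(s"Starting $server")
})

serverOpt.foreach { server =>   // after: idiomatic brace syntax for one function argument
  println(s"Starting $server")
}
```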
diff --git a/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/TransactionProcessor.scala b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/TransactionProcessor.scala
index b15c2097e5..19e736f016 100644
--- a/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/TransactionProcessor.scala
+++ b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/TransactionProcessor.scala
@@ -110,7 +110,7 @@ private class TransactionProcessor(val channel: Channel, val seqNum: String,
eventBatch.setErrorMsg("Something went wrong. Channel was " +
"unable to create a transaction!")
}
- txOpt.foreach(tx => {
+ txOpt.foreach { tx =>
tx.begin()
val events = new util.ArrayList[SparkSinkEvent](maxBatchSize)
val loop = new Breaks
@@ -145,7 +145,7 @@ private class TransactionProcessor(val channel: Channel, val seqNum: String,
// At this point, the events are available, so fill them into the event batch
eventBatch = new EventBatch("", seqNum, events)
}
- })
+ }
} catch {
case interrupted: InterruptedException =>
// Don't pollute logs if the InterruptedException came from this being stopped
@@ -156,9 +156,9 @@ private class TransactionProcessor(val channel: Channel, val seqNum: String,
logWarning("Error while processing transaction.", e)
eventBatch.setErrorMsg(e.getMessage)
try {
- txOpt.foreach(tx => {
+ txOpt.foreach { tx =>
rollbackAndClose(tx, close = true)
- })
+ }
} finally {
txOpt = None
}
@@ -174,7 +174,7 @@ private class TransactionProcessor(val channel: Channel, val seqNum: String,
*/
private def processAckOrNack() {
batchAckLatch.await(transactionTimeout, TimeUnit.SECONDS)
- txOpt.foreach(tx => {
+ txOpt.foreach { tx =>
if (batchSuccess) {
try {
logDebug("Committing transaction")
@@ -197,7 +197,7 @@ private class TransactionProcessor(val channel: Channel, val seqNum: String,
// cause issues. This is required to ensure the TransactionProcessor instance is not leaked
parent.removeAndGetProcessor(seqNum)
}
- })
+ }
}
/**
diff --git a/external/flume-sink/src/test/resources/log4j.properties b/external/flume-sink/src/test/resources/log4j.properties
index 42df8792f1..1e3f163f95 100644
--- a/external/flume-sink/src/test/resources/log4j.properties
+++ b/external/flume-sink/src/test/resources/log4j.properties
@@ -24,5 +24,5 @@ log4j.appender.file.layout=org.apache.log4j.PatternLayout
log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n
# Ignore messages below warning level from Jetty, because it's a bit verbose
-log4j.logger.org.spark-project.jetty=WARN
+log4j.logger.org.spark_project.jetty=WARN
diff --git a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala
index 7dc9606913..13aa817492 100644
--- a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala
+++ b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala
@@ -130,8 +130,10 @@ class FlumeEventServer(receiver: FlumeReceiver) extends AvroSourceProtocol {
}
}
-/** A NetworkReceiver which listens for events using the
- * Flume Avro interface. */
+/**
+ * A NetworkReceiver which listens for events using the
+ * Flume Avro interface.
+ */
private[streaming]
class FlumeReceiver(
host: String,
@@ -185,13 +187,14 @@ class FlumeReceiver(
override def preferredLocation: Option[String] = Option(host)
- /** A Netty Pipeline factory that will decompress incoming data from
- * and the Netty client and compress data going back to the client.
- *
- * The compression on the return is required because Flume requires
- * a successful response to indicate it can remove the event/batch
- * from the configured channel
- */
+ /**
+   * A Netty Pipeline factory that will decompress incoming data from
+   * the Netty client and compress data going back to the client.
+ *
+ * The compression on the return is required because Flume requires
+ * a successful response to indicate it can remove the event/batch
+   * from the configured channel.
+ */
private[streaming]
class CompressionChannelPipelineFactory extends ChannelPipelineFactory {
def getPipeline(): ChannelPipeline = {
diff --git a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumePollingInputDStream.scala b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumePollingInputDStream.scala
index 250bfc1718..54565840fa 100644
--- a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumePollingInputDStream.scala
+++ b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumePollingInputDStream.scala
@@ -79,11 +79,11 @@ private[streaming] class FlumePollingReceiver(
override def onStart(): Unit = {
// Create the connections to each Flume agent.
- addresses.foreach(host => {
+ addresses.foreach { host =>
val transceiver = new NettyTransceiver(host, channelFactory)
val client = SpecificRequestor.getClient(classOf[SparkFlumeProtocol.Callback], transceiver)
connections.add(new FlumeConnection(transceiver, client))
- })
+ }
for (i <- 0 until parallelism) {
logInfo("Starting Flume Polling Receiver worker threads..")
// Threads that pull data from Flume.
diff --git a/external/flume/src/main/scala/org/apache/spark/streaming/flume/PollingFlumeTestUtils.scala b/external/flume/src/main/scala/org/apache/spark/streaming/flume/PollingFlumeTestUtils.scala
index 1a96df6e94..6a4dafb8ed 100644
--- a/external/flume/src/main/scala/org/apache/spark/streaming/flume/PollingFlumeTestUtils.scala
+++ b/external/flume/src/main/scala/org/apache/spark/streaming/flume/PollingFlumeTestUtils.scala
@@ -123,9 +123,9 @@ private[flume] class PollingFlumeTestUtils {
val latch = new CountDownLatch(batchCount * channels.size)
sinks.foreach(_.countdownWhenBatchReceived(latch))
- channels.foreach(channel => {
+ channels.foreach { channel =>
executorCompletion.submit(new TxnSubmitter(channel))
- })
+ }
for (i <- 0 until channels.size) {
executorCompletion.take()
diff --git a/external/flume/src/test/java/org/apache/spark/streaming/flume/JavaFlumeStreamSuite.java b/external/flume/src/test/java/org/apache/spark/streaming/flume/JavaFlumeStreamSuite.java
index 3b5e0c7746..ada05f203b 100644
--- a/external/flume/src/test/java/org/apache/spark/streaming/flume/JavaFlumeStreamSuite.java
+++ b/external/flume/src/test/java/org/apache/spark/streaming/flume/JavaFlumeStreamSuite.java
@@ -27,10 +27,11 @@ public class JavaFlumeStreamSuite extends LocalJavaStreamingContext {
@Test
public void testFlumeStream() {
// tests the API, does not actually test data receiving
- JavaReceiverInputDStream<SparkFlumeEvent> test1 = FlumeUtils.createStream(ssc, "localhost", 12345);
- JavaReceiverInputDStream<SparkFlumeEvent> test2 = FlumeUtils.createStream(ssc, "localhost", 12345,
- StorageLevel.MEMORY_AND_DISK_SER_2());
- JavaReceiverInputDStream<SparkFlumeEvent> test3 = FlumeUtils.createStream(ssc, "localhost", 12345,
- StorageLevel.MEMORY_AND_DISK_SER_2(), false);
+ JavaReceiverInputDStream<SparkFlumeEvent> test1 = FlumeUtils.createStream(ssc, "localhost",
+ 12345);
+ JavaReceiverInputDStream<SparkFlumeEvent> test2 = FlumeUtils.createStream(ssc, "localhost",
+ 12345, StorageLevel.MEMORY_AND_DISK_SER_2());
+ JavaReceiverInputDStream<SparkFlumeEvent> test3 = FlumeUtils.createStream(ssc, "localhost",
+ 12345, StorageLevel.MEMORY_AND_DISK_SER_2(), false);
}
}
diff --git a/external/flume/src/test/resources/log4j.properties b/external/flume/src/test/resources/log4j.properties
index 75e3b53a09..fd51f8faf5 100644
--- a/external/flume/src/test/resources/log4j.properties
+++ b/external/flume/src/test/resources/log4j.properties
@@ -24,5 +24,5 @@ log4j.appender.file.layout=org.apache.log4j.PatternLayout
log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n
# Ignore messages below warning level from Jetty, because it's a bit verbose
-log4j.logger.org.spark-project.jetty=WARN
+log4j.logger.org.spark_project.jetty=WARN
diff --git a/external/java8-tests/README.md b/external/java8-tests/README.md
index dc9e87f2ee..aa87901695 100644
--- a/external/java8-tests/README.md
+++ b/external/java8-tests/README.md
@@ -8,16 +8,14 @@ to your Java location. The set-up depends a bit on the build system:
`-java-home` to the sbt launch script. If a Java 8 JDK is detected sbt will automatically
include the Java 8 test project.
- `$ JAVA_HOME=/opt/jdk1.8.0/ build/sbt clean "test-only org.apache.spark.Java8APISuite"`
+    `$ JAVA_HOME=/opt/jdk1.8.0/ build/sbt clean java8-tests/test`
* For Maven users,
- Maven users can also refer to their Java 8 directory using JAVA_HOME. However, Maven will not
- automatically detect the presence of a Java 8 JDK, so a special build profile `-Pjava8-tests`
- must be used.
+ Maven users can also refer to their Java 8 directory using JAVA_HOME.
`$ JAVA_HOME=/opt/jdk1.8.0/ mvn clean install -DskipTests`
- `$ JAVA_HOME=/opt/jdk1.8.0/ mvn test -Pjava8-tests -DwildcardSuites=org.apache.spark.Java8APISuite`
+ `$ JAVA_HOME=/opt/jdk1.8.0/ mvn -pl :java8-tests_2.11 test`
Note that the above command can only be run from project root directory since this module
depends on core and the test-jars of core and streaming. This means an install step is
diff --git a/external/java8-tests/pom.xml b/external/java8-tests/pom.xml
index 0ad9c5303a..1ea9196e9d 100644
--- a/external/java8-tests/pom.xml
+++ b/external/java8-tests/pom.xml
@@ -27,7 +27,7 @@
<groupId>org.apache.spark</groupId>
<artifactId>java8-tests_2.11</artifactId>
<packaging>pom</packaging>
- <name>Spark Project Java8 Tests POM</name>
+ <name>Spark Project Java 8 Tests</name>
<properties>
<sbt.project.name>java8-tests</sbt.project.name>
@@ -60,15 +60,22 @@
</dependency>
<dependency>
<groupId>org.apache.spark</groupId>
+ <artifactId>spark-sql_${scala.binary.version}</artifactId>
+ <version>${project.version}</version>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.spark</groupId>
+ <artifactId>spark-sql_${scala.binary.version}</artifactId>
+ <version>${project.version}</version>
+ <type>test-jar</type>
+ <scope>test</scope>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.spark</groupId>
<artifactId>spark-test-tags_${scala.binary.version}</artifactId>
</dependency>
</dependencies>
- <profiles>
- <profile>
- <id>java8-tests</id>
- </profile>
- </profiles>
<build>
<plugins>
<plugin>
@@ -87,74 +94,26 @@
</plugin>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
- <artifactId>maven-surefire-plugin</artifactId>
- <executions>
- <execution>
- <id>test</id>
- <goals>
- <goal>test</goal>
- </goals>
- </execution>
- </executions>
- <configuration>
- <systemPropertyVariables>
- <!-- For some reason surefire isn't setting this log4j file on the
- test classpath automatically. So we add it manually. -->
- <log4j.configuration>
- file:src/test/resources/log4j.properties
- </log4j.configuration>
- </systemPropertyVariables>
- <skipTests>false</skipTests>
- <includes>
- <include>**/Suite*.java</include>
- <include>**/*Suite.java</include>
- </includes>
- </configuration>
- </plugin>
- <plugin>
- <groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-compiler-plugin</artifactId>
- <executions>
- <execution>
- <id>test-compile-first</id>
- <phase>process-test-resources</phase>
- <goals>
- <goal>testCompile</goal>
- </goals>
- </execution>
- </executions>
<configuration>
- <fork>true</fork>
- <verbose>true</verbose>
<forceJavacCompilerUse>true</forceJavacCompilerUse>
<source>1.8</source>
- <compilerVersion>1.8</compilerVersion>
<target>1.8</target>
- <encoding>UTF-8</encoding>
- <maxmem>1024m</maxmem>
+ <compilerVersion>1.8</compilerVersion>
</configuration>
</plugin>
<plugin>
- <!-- disabled -->
<groupId>net.alchim31.maven</groupId>
<artifactId>scala-maven-plugin</artifactId>
- <executions>
- <execution>
- <phase>none</phase>
- </execution>
- <execution>
- <id>scala-compile-first</id>
- <phase>none</phase>
- </execution>
- <execution>
- <id>scala-test-compile-first</id>
- <phase>none</phase>
- </execution>
- <execution>
- <id>attach-scaladocs</id>
- <phase>none</phase>
- </execution>
- </executions>
+ <configuration>
+ <javacArgs>
+ <javacArg>-source</javacArg>
+ <javacArg>1.8</javacArg>
+ <javacArg>-target</javacArg>
+ <javacArg>1.8</javacArg>
+ <javacArg>-Xlint:all,-serial,-path</javacArg>
+ </javacArgs>
+ </configuration>
</plugin>
</plugins>
</build>
diff --git a/external/java8-tests/src/test/java/org/apache/spark/Java8APISuite.java b/external/java8-tests/src/test/java/org/apache/spark/Java8APISuite.java
index c0b58e713f..6ac5ca9cf5 100644
--- a/external/java8-tests/src/test/java/org/apache/spark/Java8APISuite.java
+++ b/external/java8-tests/src/test/java/org/apache/spark/Java8APISuite.java
@@ -188,7 +188,7 @@ public class Java8APISuite implements Serializable {
public void flatMap() {
JavaRDD<String> rdd = sc.parallelize(Arrays.asList("Hello World!",
"The quick brown fox jumps over the lazy dog."));
- JavaRDD<String> words = rdd.flatMap(x -> Arrays.asList(x.split(" ")));
+ JavaRDD<String> words = rdd.flatMap(x -> Arrays.asList(x.split(" ")).iterator());
Assert.assertEquals("Hello", words.first());
Assert.assertEquals(11, words.count());
@@ -198,7 +198,7 @@ public class Java8APISuite implements Serializable {
for (String word : s.split(" ")) {
pairs2.add(new Tuple2<>(word, word));
}
- return pairs2;
+ return pairs2.iterator();
});
Assert.assertEquals(new Tuple2<>("Hello", "Hello"), pairs.first());
@@ -209,7 +209,7 @@ public class Java8APISuite implements Serializable {
for (String word : s.split(" ")) {
lengths.add((double) word.length());
}
- return lengths;
+ return lengths.iterator();
});
Assert.assertEquals(5.0, doubles.first(), 0.01);
@@ -227,7 +227,7 @@ public class Java8APISuite implements Serializable {
// Regression test for SPARK-668:
JavaPairRDD<String, Integer> swapped =
- pairRDD.flatMapToPair(x -> Collections.singletonList(x.swap()));
+ pairRDD.flatMapToPair(x -> Collections.singletonList(x.swap()).iterator());
swapped.collect();
// There was never a bug here, but it's worth testing:
@@ -242,7 +242,7 @@ public class Java8APISuite implements Serializable {
while (iter.hasNext()) {
sum += iter.next();
}
- return Collections.singletonList(sum);
+ return Collections.singletonList(sum).iterator();
});
Assert.assertEquals("[3, 7]", partitionSums.collect().toString());
diff --git a/external/java8-tests/src/test/java/org/apache/spark/sql/Java8DatasetAggregatorSuite.java b/external/java8-tests/src/test/java/org/apache/spark/sql/Java8DatasetAggregatorSuite.java
new file mode 100644
index 0000000000..23abfa3970
--- /dev/null
+++ b/external/java8-tests/src/test/java/org/apache/spark/sql/Java8DatasetAggregatorSuite.java
@@ -0,0 +1,61 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package test.org.apache.spark.sql.sources;
+
+import java.util.Arrays;
+
+import org.junit.Assert;
+import org.junit.Test;
+import scala.Tuple2;
+
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.KeyValueGroupedDataset;
+import org.apache.spark.sql.expressions.java.typed;
+
+/**
+ * Suite that replicates tests in JavaDatasetAggregatorSuite using lambda syntax.
+ */
+public class Java8DatasetAggregatorSuite extends JavaDatasetAggregatorSuiteBase {
+ @Test
+ public void testTypedAggregationAverage() {
+ KeyValueGroupedDataset<String, Tuple2<String, Integer>> grouped = generateGroupedDataset();
+ Dataset<Tuple2<String, Double>> agged = grouped.agg(typed.avg(v -> (double)(v._2() * 2)));
+ Assert.assertEquals(Arrays.asList(tuple2("a", 3.0), tuple2("b", 6.0)), agged.collectAsList());
+ }
+
+ @Test
+ public void testTypedAggregationCount() {
+ KeyValueGroupedDataset<String, Tuple2<String, Integer>> grouped = generateGroupedDataset();
+ Dataset<Tuple2<String, Long>> agged = grouped.agg(typed.count(v -> v));
+ Assert.assertEquals(Arrays.asList(tuple2("a", 2), tuple2("b", 1)), agged.collectAsList());
+ }
+
+ @Test
+ public void testTypedAggregationSumDouble() {
+ KeyValueGroupedDataset<String, Tuple2<String, Integer>> grouped = generateGroupedDataset();
+ Dataset<Tuple2<String, Double>> agged = grouped.agg(typed.sum(v -> (double)v._2()));
+ Assert.assertEquals(Arrays.asList(tuple2("a", 3.0), tuple2("b", 3.0)), agged.collectAsList());
+ }
+
+ @Test
+ public void testTypedAggregationSumLong() {
+ KeyValueGroupedDataset<String, Tuple2<String, Integer>> grouped = generateGroupedDataset();
+ Dataset<Tuple2<String, Long>> agged = grouped.agg(typed.sumLong(v -> (long)v._2()));
+ Assert.assertEquals(Arrays.asList(tuple2("a", 3), tuple2("b", 3)), agged.collectAsList());
+ }
+}
diff --git a/external/java8-tests/src/test/java/org/apache/spark/streaming/Java8APISuite.java b/external/java8-tests/src/test/java/org/apache/spark/streaming/Java8APISuite.java
index 604d818ef1..d0fed303e6 100644
--- a/external/java8-tests/src/test/java/org/apache/spark/streaming/Java8APISuite.java
+++ b/external/java8-tests/src/test/java/org/apache/spark/streaming/Java8APISuite.java
@@ -29,6 +29,7 @@ import org.junit.Test;
import org.apache.spark.Accumulator;
import org.apache.spark.HashPartitioner;
+import org.apache.spark.api.java.Optional;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.PairFunction;
@@ -95,7 +96,7 @@ public class Java8APISuite extends LocalJavaStreamingContext implements Serializ
while (in.hasNext()) {
out = out + in.next().toUpperCase();
}
- return Lists.newArrayList(out);
+ return Lists.newArrayList(out).iterator();
});
JavaTestUtils.attachTestOutputStream(mapped);
List<List<String>> result = JavaTestUtils.runStreams(ssc, 2, 2);
@@ -351,7 +352,8 @@ public class Java8APISuite extends LocalJavaStreamingContext implements Serializ
Arrays.asList("a", "t", "h", "l", "e", "t", "i", "c", "s"));
JavaDStream<String> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1);
- JavaDStream<String> flatMapped = stream.flatMap(s -> Lists.newArrayList(s.split("(?!^)")));
+ JavaDStream<String> flatMapped = stream.flatMap(
+ s -> Lists.newArrayList(s.split("(?!^)")).iterator());
JavaTestUtils.attachTestOutputStream(flatMapped);
List<List<String>> result = JavaTestUtils.runStreams(ssc, 3, 3);
@@ -360,8 +362,8 @@ public class Java8APISuite extends LocalJavaStreamingContext implements Serializ
@Test
public void testForeachRDD() {
- final Accumulator<Integer> accumRdd = ssc.sc().accumulator(0);
- final Accumulator<Integer> accumEle = ssc.sc().accumulator(0);
+ final Accumulator<Integer> accumRdd = ssc.sparkContext().accumulator(0);
+ final Accumulator<Integer> accumEle = ssc.sparkContext().accumulator(0);
List<List<Integer>> inputData = Arrays.asList(
Arrays.asList(1,1,1),
Arrays.asList(1,1,1));
@@ -375,7 +377,9 @@ public class Java8APISuite extends LocalJavaStreamingContext implements Serializ
});
// This is a test to make sure foreachRDD(VoidFunction2) can be called from Java
- stream.foreachRDD((rdd, time) -> null);
+ stream.foreachRDD((rdd, time) -> {
+ return;
+ });
JavaTestUtils.runStreams(ssc, 2, 2);
@@ -423,7 +427,7 @@ public class Java8APISuite extends LocalJavaStreamingContext implements Serializ
for (String letter : s.split("(?!^)")) {
out.add(new Tuple2<>(s.length(), letter));
}
- return out;
+ return out.iterator();
});
JavaTestUtils.attachTestOutputStream(flatMapped);
@@ -541,7 +545,7 @@ public class Java8APISuite extends LocalJavaStreamingContext implements Serializ
Tuple2<String, Integer> next = in.next();
out.add(next.swap());
}
- return out;
+ return out.iterator();
});
JavaTestUtils.attachTestOutputStream(reversed);
@@ -598,7 +602,7 @@ public class Java8APISuite extends LocalJavaStreamingContext implements Serializ
for (Character s : in._1().toCharArray()) {
out.add(new Tuple2<>(in._2(), s.toString()));
}
- return out;
+ return out.iterator();
});
JavaTestUtils.attachTestOutputStream(flatMapped);
@@ -871,7 +875,7 @@ public class Java8APISuite extends LocalJavaStreamingContext implements Serializ
JavaMapWithStateDStream<String, Integer, Boolean, Double> stateDstream =
wordsDstream.mapWithState(
- StateSpec.<String, Integer, Boolean, Double> function((time, key, value, state) -> {
+ StateSpec.<String, Integer, Boolean, Double>function((time, key, value, state) -> {
// Use all State's methods here
state.exists();
state.get();
diff --git a/external/java8-tests/src/test/resources/log4j.properties b/external/java8-tests/src/test/resources/log4j.properties
index eb3b1999eb..3706a6e361 100644
--- a/external/java8-tests/src/test/resources/log4j.properties
+++ b/external/java8-tests/src/test/resources/log4j.properties
@@ -24,5 +24,4 @@ log4j.appender.file.layout=org.apache.log4j.PatternLayout
log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n
# Ignore messages below warning level from Jetty, because it's a bit verbose
-log4j.logger.org.spark-project.jetty=WARN
-org.spark-project.jetty.LEVEL=WARN
+log4j.logger.org.spark_project.jetty=WARN
diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDDPartition.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDDPartition.scala
index a660d2a00c..02917becf0 100644
--- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDDPartition.scala
+++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDDPartition.scala
@@ -19,13 +19,14 @@ package org.apache.spark.streaming.kafka
import org.apache.spark.Partition
-/** @param topic kafka topic name
- * @param partition kafka partition id
- * @param fromOffset inclusive starting offset
- * @param untilOffset exclusive ending offset
- * @param host preferred kafka host, i.e. the leader at the time the rdd was created
- * @param port preferred kafka host's port
- */
+/**
+ * @param topic kafka topic name
+ * @param partition kafka partition id
+ * @param fromOffset inclusive starting offset
+ * @param untilOffset exclusive ending offset
+ * @param host preferred kafka host, i.e. the leader at the time the rdd was created
+ * @param port preferred kafka host's port
+ */
private[kafka]
class KafkaRDDPartition(
val index: Int,
diff --git a/external/kafka/src/test/resources/log4j.properties b/external/kafka/src/test/resources/log4j.properties
index 75e3b53a09..fd51f8faf5 100644
--- a/external/kafka/src/test/resources/log4j.properties
+++ b/external/kafka/src/test/resources/log4j.properties
@@ -24,5 +24,5 @@ log4j.appender.file.layout=org.apache.log4j.PatternLayout
log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n
# Ignore messages below warning level from Jetty, because it's a bit verbose
-log4j.logger.org.spark-project.jetty=WARN
+log4j.logger.org.spark_project.jetty=WARN
diff --git a/external/kinesis-asl/src/main/resources/log4j.properties b/external/kinesis-asl/src/main/resources/log4j.properties
index 6cdc9286c5..8118d12c5d 100644
--- a/external/kinesis-asl/src/main/resources/log4j.properties
+++ b/external/kinesis-asl/src/main/resources/log4j.properties
@@ -31,7 +31,7 @@ log4j.appender.console.layout=org.apache.log4j.PatternLayout
log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n
# Settings to quiet third party logs that are too verbose
-log4j.logger.org.spark-project.jetty=WARN
-log4j.logger.org.spark-project.jetty.util.component.AbstractLifeCycle=ERROR
+log4j.logger.org.spark_project.jetty=WARN
+log4j.logger.org.spark_project.jetty.util.component.AbstractLifeCycle=ERROR
log4j.logger.org.apache.spark.repl.SparkIMain$exprTyper=INFO
log4j.logger.org.apache.spark.repl.SparkILoop$SparkILoopInterpreter=INFO \ No newline at end of file
diff --git a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala
index 41c6ab123b..80e0cce055 100644
--- a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala
+++ b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala
@@ -73,7 +73,7 @@ private[kinesis] class KinesisRecordProcessor[T](receiver: KinesisReceiver[T], w
logDebug(s"Stored: Worker $workerId stored ${batch.size} records for shardId $shardId")
receiver.setCheckpointer(shardId, checkpointer)
} catch {
- case NonFatal(e) => {
+ case NonFatal(e) =>
/*
* If there is a failure within the batch, the batch will not be checkpointed.
* This will potentially cause records since the last checkpoint to be processed
@@ -84,7 +84,6 @@ private[kinesis] class KinesisRecordProcessor[T](receiver: KinesisReceiver[T], w
/* Rethrow the exception to the Kinesis Worker that is managing this RecordProcessor. */
throw e
- }
}
} else {
/* RecordProcessor has been stopped. */
@@ -148,29 +147,25 @@ private[kinesis] object KinesisRecordProcessor extends Logging {
/* If the function failed, either retry or throw the exception */
case util.Failure(e) => e match {
/* Retry: Throttling or other Retryable exception has occurred */
- case _: ThrottlingException | _: KinesisClientLibDependencyException if numRetriesLeft > 1
- => {
- val backOffMillis = Random.nextInt(maxBackOffMillis)
- Thread.sleep(backOffMillis)
- logError(s"Retryable Exception: Random backOffMillis=${backOffMillis}", e)
- retryRandom(expression, numRetriesLeft - 1, maxBackOffMillis)
- }
+ case _: ThrottlingException | _: KinesisClientLibDependencyException
+ if numRetriesLeft > 1 =>
+ val backOffMillis = Random.nextInt(maxBackOffMillis)
+ Thread.sleep(backOffMillis)
+ logError(s"Retryable Exception: Random backOffMillis=${backOffMillis}", e)
+ retryRandom(expression, numRetriesLeft - 1, maxBackOffMillis)
/* Throw: Shutdown has been requested by the Kinesis Client Library. */
- case _: ShutdownException => {
+ case _: ShutdownException =>
logError(s"ShutdownException: Caught shutdown exception, skipping checkpoint.", e)
throw e
- }
/* Throw: Non-retryable exception has occurred with the Kinesis Client Library */
- case _: InvalidStateException => {
+ case _: InvalidStateException =>
logError(s"InvalidStateException: Cannot save checkpoint to the DynamoDB table used" +
s" by the Amazon Kinesis Client Library. Table likely doesn't exist.", e)
throw e
- }
/* Throw: Unexpected exception has occurred */
- case _ => {
+ case _ =>
logError(s"Unexpected, non-retryable exception.", e)
throw e
- }
}
}
}
diff --git a/external/kinesis-asl/src/test/resources/log4j.properties b/external/kinesis-asl/src/test/resources/log4j.properties
index edbecdae92..3706a6e361 100644
--- a/external/kinesis-asl/src/test/resources/log4j.properties
+++ b/external/kinesis-asl/src/test/resources/log4j.properties
@@ -24,4 +24,4 @@ log4j.appender.file.layout=org.apache.log4j.PatternLayout
log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n
# Ignore messages below warning level from Jetty, because it's a bit verbose
-log4j.logger.org.spark-project.jetty=WARN
+log4j.logger.org.spark_project.jetty=WARN
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala b/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala
index fcb1b5999f..868658dfe5 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala
@@ -276,7 +276,7 @@ class GraphOps[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED]) extends Seriali
if (Random.nextDouble() < probability) { Some(vidVvals._1) }
else { None }
}
- if (selectedVertices.count > 1) {
+ if (selectedVertices.count > 0) {
found = true
val collectedVertices = selectedVertices.collect()
retVal = collectedVertices(Random.nextInt(collectedVertices.length))
@@ -415,11 +415,11 @@ class GraphOps[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED]) extends Seriali
}
/**
- * Compute the connected component membership of each vertex and return a graph with the vertex
- * value containing the lowest vertex id in the connected component containing that vertex.
- *
- * @see [[org.apache.spark.graphx.lib.ConnectedComponents$#run]]
- */
+ * Compute the connected component membership of each vertex and return a graph with the vertex
+ * value containing the lowest vertex id in the connected component containing that vertex.
+ *
+ * @see [[org.apache.spark.graphx.lib.ConnectedComponents$#run]]
+ */
def connectedComponents(): Graph[VertexId, ED] = {
ConnectedComponents.run(graph)
}
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala b/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala
index d2e51d2ec4..646462b4a8 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala
@@ -119,7 +119,7 @@ object Pregel extends Logging {
mergeMsg: (A, A) => A)
: Graph[VD, ED] =
{
- require(maxIterations > 0, s"Maximum of iterations must be greater than 0," +
+ require(maxIterations > 0, s"Maximum number of iterations must be greater than 0," +
s" but got ${maxIterations}")
var g = graph.mapVertices((vid, vdata) => vprog(vid, vdata, initialMsg)).cache()
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/ConnectedComponents.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/ConnectedComponents.scala
index 137c512c99..4e9b13162e 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/lib/ConnectedComponents.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/ConnectedComponents.scala
@@ -60,15 +60,15 @@ object ConnectedComponents {
} // end of connectedComponents
/**
- * Compute the connected component membership of each vertex and return a graph with the vertex
- * value containing the lowest vertex id in the connected component containing that vertex.
- *
- * @tparam VD the vertex attribute type (discarded in the computation)
- * @tparam ED the edge attribute type (preserved in the computation)
- * @param graph the graph for which to compute the connected components
- * @return a graph with vertex attributes containing the smallest vertex in each
- * connected component
- */
+ * Compute the connected component membership of each vertex and return a graph with the vertex
+ * value containing the lowest vertex id in the connected component containing that vertex.
+ *
+ * @tparam VD the vertex attribute type (discarded in the computation)
+ * @tparam ED the edge attribute type (preserved in the computation)
+ * @param graph the graph for which to compute the connected components
+ * @return a graph with vertex attributes containing the smallest vertex in each
+ * connected component
+ */
def run[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED]): Graph[VertexId, ED] = {
run(graph, Int.MaxValue)
}
diff --git a/graphx/src/test/resources/log4j.properties b/graphx/src/test/resources/log4j.properties
index eb3b1999eb..3706a6e361 100644
--- a/graphx/src/test/resources/log4j.properties
+++ b/graphx/src/test/resources/log4j.properties
@@ -24,5 +24,4 @@ log4j.appender.file.layout=org.apache.log4j.PatternLayout
log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n
# Ignore messages below warning level from Jetty, because it's a bit verbose
-log4j.logger.org.spark-project.jetty=WARN
-org.spark-project.jetty.LEVEL=WARN
+log4j.logger.org.spark_project.jetty=WARN
diff --git a/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala
index cb981797d3..96aa262a39 100644
--- a/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala
+++ b/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala
@@ -404,4 +404,13 @@ class GraphSuite extends SparkFunSuite with LocalSparkContext {
assert(sc.getPersistentRDDs.isEmpty)
}
}
+
+ test("SPARK-14219: pickRandomVertex") {
+ withSpark { sc =>
+ val vert = sc.parallelize(List((1L, "a")), 1)
+ val edges = sc.parallelize(List(Edge[Long](1L, 1L)), 1)
+ val g0 = Graph(vert, edges)
+ assert(g0.pickRandomVertex() === 1L)
+ }
+ }
}
diff --git a/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java b/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java
index 587fda7a3c..c7488082ca 100644
--- a/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java
+++ b/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java
@@ -74,7 +74,8 @@ abstract class AbstractCommandBuilder {
* SparkLauncher constructor that takes an environment), and may be modified to
* include other variables needed by the process to be executed.
*/
- abstract List<String> buildCommand(Map<String, String> env) throws IOException;
+ abstract List<String> buildCommand(Map<String, String> env)
+ throws IOException, IllegalArgumentException;
/**
* Builds a list of arguments to run java.
@@ -144,10 +145,26 @@ abstract class AbstractCommandBuilder {
boolean isTesting = "1".equals(getenv("SPARK_TESTING"));
if (prependClasses || isTesting) {
String scala = getScalaVersion();
- List<String> projects = Arrays.asList("core", "repl", "mllib", "graphx",
- "streaming", "tools", "sql/catalyst", "sql/core", "sql/hive", "sql/hive-thriftserver",
- "yarn", "launcher",
- "common/network-common", "common/network-shuffle", "common/network-yarn");
+ List<String> projects = Arrays.asList(
+ "common/network-common",
+ "common/network-shuffle",
+ "common/network-yarn",
+ "common/sketch",
+ "common/tags",
+ "common/unsafe",
+ "core",
+ "examples",
+ "graphx",
+ "launcher",
+ "mllib",
+ "repl",
+ "sql/catalyst",
+ "sql/core",
+ "sql/hive",
+ "sql/hive-thriftserver",
+ "streaming",
+ "yarn"
+ );
if (prependClasses) {
if (!isTesting) {
System.err.println(
@@ -174,31 +191,12 @@ abstract class AbstractCommandBuilder {
// Add Spark jars to the classpath. For the testing case, we rely on the test code to set and
// propagate the test classpath appropriately. For normal invocation, look for the jars
// directory under SPARK_HOME.
- String jarsDir = findJarsDir(!isTesting);
+ boolean isTestingSql = "1".equals(getenv("SPARK_SQL_TESTING"));
+ String jarsDir = findJarsDir(getSparkHome(), getScalaVersion(), !isTesting && !isTestingSql);
if (jarsDir != null) {
addToClassPath(cp, join(File.separator, jarsDir, "*"));
}
- // Datanucleus jars must be included on the classpath. Datanucleus jars do not work if only
- // included in the uber jar as plugin.xml metadata is lost. Both sbt and maven will populate
- // "lib_managed/jars/" with the datanucleus jars when Spark is built with Hive
- File libdir;
- if (new File(sparkHome, "RELEASE").isFile()) {
- libdir = new File(sparkHome, "lib");
- } else {
- libdir = new File(sparkHome, "lib_managed/jars");
- }
-
- if (libdir.isDirectory()) {
- for (File jar : libdir.listFiles()) {
- if (jar.getName().startsWith("datanucleus-")) {
- addToClassPath(cp, jar.getAbsolutePath());
- }
- }
- } else {
- checkState(isTesting, "Library directory '%s' does not exist.", libdir.getAbsolutePath());
- }
-
addToClassPath(cp, getenv("HADOOP_CONF_DIR"));
addToClassPath(cp, getenv("YARN_CONF_DIR"));
addToClassPath(cp, getenv("SPARK_DIST_CLASSPATH"));
@@ -311,27 +309,6 @@ abstract class AbstractCommandBuilder {
return props;
}
- private String findJarsDir(boolean failIfNotFound) {
- // TODO: change to the correct directory once the assembly build is changed.
- String sparkHome = getSparkHome();
- File libdir;
- if (new File(sparkHome, "RELEASE").isFile()) {
- libdir = new File(sparkHome, "lib");
- checkState(!failIfNotFound || libdir.isDirectory(),
- "Library directory '%s' does not exist.",
- libdir.getAbsolutePath());
- } else {
- libdir = new File(sparkHome, String.format("assembly/target/scala-%s", getScalaVersion()));
- if (!libdir.isDirectory()) {
- checkState(!failIfNotFound,
- "Library directory '%s' does not exist; make sure Spark is built.",
- libdir.getAbsolutePath());
- libdir = null;
- }
- }
- return libdir != null ? libdir.getAbsolutePath() : null;
- }
-
private String getConfDir() {
String confDir = getenv("SPARK_CONF_DIR");
return confDir != null ? confDir : join(File.separator, getSparkHome(), "conf");
diff --git a/launcher/src/main/java/org/apache/spark/launcher/CommandBuilderUtils.java b/launcher/src/main/java/org/apache/spark/launcher/CommandBuilderUtils.java
index 39fdf300e2..91586aad7b 100644
--- a/launcher/src/main/java/org/apache/spark/launcher/CommandBuilderUtils.java
+++ b/launcher/src/main/java/org/apache/spark/launcher/CommandBuilderUtils.java
@@ -34,7 +34,7 @@ class CommandBuilderUtils {
/** The set of known JVM vendors. */
enum JavaVendor {
Oracle, IBM, OpenJDK, Unknown
- };
+ }
/** Returns whether the given string is null or empty. */
static boolean isEmpty(String s) {
@@ -349,4 +349,29 @@ class CommandBuilderUtils {
return Integer.parseInt(version[1]);
}
}
+
+ /**
+ * Find the location of the Spark jars dir, depending on whether we're looking at a build
+ * or a distribution directory.
+ */
+ static String findJarsDir(String sparkHome, String scalaVersion, boolean failIfNotFound) {
+ // TODO: change to the correct directory once the assembly build is changed.
+ File libdir;
+ if (new File(sparkHome, "RELEASE").isFile()) {
+ libdir = new File(sparkHome, "jars");
+ checkState(!failIfNotFound || libdir.isDirectory(),
+ "Library directory '%s' does not exist.",
+ libdir.getAbsolutePath());
+ } else {
+ libdir = new File(sparkHome, String.format("assembly/target/scala-%s/jars", scalaVersion));
+ if (!libdir.isDirectory()) {
+ checkState(!failIfNotFound,
+ "Library directory '%s' does not exist; make sure Spark is built.",
+ libdir.getAbsolutePath());
+ libdir = null;
+ }
+ }
+ return libdir != null ? libdir.getAbsolutePath() : null;
+ }
+
}
diff --git a/launcher/src/main/java/org/apache/spark/launcher/SparkClassCommandBuilder.java b/launcher/src/main/java/org/apache/spark/launcher/SparkClassCommandBuilder.java
index 6b9d36cc0b..82b593a3f7 100644
--- a/launcher/src/main/java/org/apache/spark/launcher/SparkClassCommandBuilder.java
+++ b/launcher/src/main/java/org/apache/spark/launcher/SparkClassCommandBuilder.java
@@ -41,7 +41,8 @@ class SparkClassCommandBuilder extends AbstractCommandBuilder {
}
@Override
- public List<String> buildCommand(Map<String, String> env) throws IOException {
+ public List<String> buildCommand(Map<String, String> env)
+ throws IOException, IllegalArgumentException {
List<String> javaOptsKeys = new ArrayList<>();
String memKey = null;
String extraClassPath = null;
@@ -80,12 +81,18 @@ class SparkClassCommandBuilder extends AbstractCommandBuilder {
}
List<String> cmd = buildJavaCommand(extraClassPath);
+
for (String key : javaOptsKeys) {
- addOptionString(cmd, System.getenv(key));
+ String envValue = System.getenv(key);
+ if (!isEmpty(envValue) && envValue.contains("Xmx")) {
+ String msg = String.format("%s is not allowed to specify max heap(Xmx) memory settings " +
+ "(was %s). Use the corresponding configuration instead.", key, envValue);
+ throw new IllegalArgumentException(msg);
+ }
+ addOptionString(cmd, envValue);
}
String mem = firstNonEmpty(memKey != null ? System.getenv(memKey) : null, DEFAULT_MEM);
- cmd.add("-Xms" + mem);
cmd.add("-Xmx" + mem);
addPermGenSizeOpt(cmd);
cmd.add(className);
diff --git a/launcher/src/main/java/org/apache/spark/launcher/SparkLauncher.java b/launcher/src/main/java/org/apache/spark/launcher/SparkLauncher.java
index a542159901..a083f05a2a 100644
--- a/launcher/src/main/java/org/apache/spark/launcher/SparkLauncher.java
+++ b/launcher/src/main/java/org/apache/spark/launcher/SparkLauncher.java
@@ -477,6 +477,6 @@ public class SparkLauncher {
// No op.
}
- };
+ }
}
diff --git a/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java b/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java
index 56e4107c5a..6941ca903c 100644
--- a/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java
+++ b/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java
@@ -132,7 +132,8 @@ class SparkSubmitCommandBuilder extends AbstractCommandBuilder {
}
@Override
- public List<String> buildCommand(Map<String, String> env) throws IOException {
+ public List<String> buildCommand(Map<String, String> env)
+ throws IOException, IllegalArgumentException {
if (PYSPARK_SHELL_RESOURCE.equals(appResource) && !printInfo) {
return buildPySparkShellCommand(env);
} else if (SPARKR_SHELL_RESOURCE.equals(appResource) && !printInfo) {
@@ -211,7 +212,8 @@ class SparkSubmitCommandBuilder extends AbstractCommandBuilder {
return args;
}
- private List<String> buildSparkSubmitCommand(Map<String, String> env) throws IOException {
+ private List<String> buildSparkSubmitCommand(Map<String, String> env)
+ throws IOException, IllegalArgumentException {
// Load the properties file and check whether spark-submit will be running the app's driver
// or just launching a cluster app. When running the driver, the JVM's argument will be
// modified to cover the driver's configuration.
@@ -227,6 +229,16 @@ class SparkSubmitCommandBuilder extends AbstractCommandBuilder {
addOptionString(cmd, System.getenv("SPARK_SUBMIT_OPTS"));
addOptionString(cmd, System.getenv("SPARK_JAVA_OPTS"));
+ // We don't want the client to specify Xmx. These have to be set by their corresponding
+ // memory flag --driver-memory or configuration entry spark.driver.memory
+ String driverExtraJavaOptions = config.get(SparkLauncher.DRIVER_EXTRA_JAVA_OPTIONS);
+ if (!isEmpty(driverExtraJavaOptions) && driverExtraJavaOptions.contains("Xmx")) {
+ String msg = String.format("Not allowed to specify max heap(Xmx) memory settings through " +
+ "java options (was %s). Use the corresponding --driver-memory or " +
+ "spark.driver.memory configuration instead.", driverExtraJavaOptions);
+ throw new IllegalArgumentException(msg);
+ }
+
if (isClientMode) {
// Figuring out where the memory value come from is a little tricky due to precedence.
// Precedence is observed in the following order:
@@ -240,9 +252,8 @@ class SparkSubmitCommandBuilder extends AbstractCommandBuilder {
isThriftServer(mainClass) ? System.getenv("SPARK_DAEMON_MEMORY") : null;
String memory = firstNonEmpty(tsMemory, config.get(SparkLauncher.DRIVER_MEMORY),
System.getenv("SPARK_DRIVER_MEMORY"), System.getenv("SPARK_MEM"), DEFAULT_MEM);
- cmd.add("-Xms" + memory);
cmd.add("-Xmx" + memory);
- addOptionString(cmd, config.get(SparkLauncher.DRIVER_EXTRA_JAVA_OPTIONS));
+ addOptionString(cmd, driverExtraJavaOptions);
mergeEnvPathList(env, getLibPathEnvName(),
config.get(SparkLauncher.DRIVER_EXTRA_LIBRARY_PATH));
}
@@ -336,6 +347,7 @@ class SparkSubmitCommandBuilder extends AbstractCommandBuilder {
}
private List<String> findExamplesJars() {
+ boolean isTesting = "1".equals(getenv("SPARK_TESTING"));
List<String> examplesJars = new ArrayList<>();
String sparkHome = getSparkHome();
@@ -346,11 +358,15 @@ class SparkSubmitCommandBuilder extends AbstractCommandBuilder {
jarsDir = new File(sparkHome,
String.format("examples/target/scala-%s/jars", getScalaVersion()));
}
- checkState(jarsDir.isDirectory(), "Examples jars directory '%s' does not exist.",
+
+ boolean foundDir = jarsDir.isDirectory();
+ checkState(isTesting || foundDir, "Examples jars directory '%s' does not exist.",
jarsDir.getAbsolutePath());
- for (File f: jarsDir.listFiles()) {
- examplesJars.add(f.getAbsolutePath());
+ if (foundDir) {
+ for (File f: jarsDir.listFiles()) {
+ examplesJars.add(f.getAbsolutePath());
+ }
}
return examplesJars;
}
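
For the driver, the memory value is resolved by precedence: SPARK_DAEMON_MEMORY (Thrift server only), then spark.driver.memory, then SPARK_DRIVER_MEMORY, then SPARK_MEM, then the 1g default, and only -Xmx is emitted. A minimal Scala sketch of that resolution, assuming Option-based lookups rather than the Java null-based firstNonEmpty:

// Sketch of the driver-memory precedence chain; names are illustrative.
object DriverMemorySketch {
  val DefaultMem = "1g"

  def firstNonEmpty(candidates: Option[String]*): String =
    candidates.flatten.find(_.nonEmpty).getOrElse(DefaultMem)

  def main(args: Array[String]): Unit = {
    val mem = firstNonEmpty(
      None,                                // SPARK_DAEMON_MEMORY, only for the Thrift server
      Some("2g"),                          // spark.driver.memory from the config
      sys.env.get("SPARK_DRIVER_MEMORY"),
      sys.env.get("SPARK_MEM"))
    println(s"-Xmx$mem")                   // -Xmx2g; no -Xms is generated
  }
}
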
diff --git a/launcher/src/test/java/org/apache/spark/launcher/LauncherServerSuite.java b/launcher/src/test/java/org/apache/spark/launcher/LauncherServerSuite.java
index 5bf2babdd1..bfe1fcc87f 100644
--- a/launcher/src/test/java/org/apache/spark/launcher/LauncherServerSuite.java
+++ b/launcher/src/test/java/org/apache/spark/launcher/LauncherServerSuite.java
@@ -83,13 +83,13 @@ public class LauncherServerSuite extends BaseSuite {
client = new TestClient(s);
client.send(new Hello(handle.getSecret(), "1.4.0"));
- assertTrue(semaphore.tryAcquire(1, TimeUnit.SECONDS));
+ assertTrue(semaphore.tryAcquire(30, TimeUnit.SECONDS));
// Make sure the server matched the client to the handle.
assertNotNull(handle.getConnection());
client.send(new SetAppId("app-id"));
- assertTrue(semaphore.tryAcquire(1, TimeUnit.SECONDS));
+ assertTrue(semaphore.tryAcquire(30, TimeUnit.SECONDS));
assertEquals("app-id", handle.getAppId());
client.send(new SetState(SparkAppHandle.State.RUNNING));
@@ -97,7 +97,7 @@ public class LauncherServerSuite extends BaseSuite {
assertEquals(SparkAppHandle.State.RUNNING, handle.getState());
handle.stop();
- Message stopMsg = client.inbound.poll(10, TimeUnit.SECONDS);
+ Message stopMsg = client.inbound.poll(30, TimeUnit.SECONDS);
assertTrue(stopMsg instanceof Stop);
} finally {
kill(handle);
@@ -175,7 +175,7 @@ public class LauncherServerSuite extends BaseSuite {
TestClient(Socket s) throws IOException {
super(s);
- this.inbound = new LinkedBlockingQueue<Message>();
+ this.inbound = new LinkedBlockingQueue<>();
this.clientThread = new Thread(this);
clientThread.setName("TestClient");
clientThread.setDaemon(true);
diff --git a/launcher/src/test/java/org/apache/spark/launcher/SparkSubmitCommandBuilderSuite.java b/launcher/src/test/java/org/apache/spark/launcher/SparkSubmitCommandBuilderSuite.java
index b7f4f2efc5..c7e8b2e03a 100644
--- a/launcher/src/test/java/org/apache/spark/launcher/SparkSubmitCommandBuilderSuite.java
+++ b/launcher/src/test/java/org/apache/spark/launcher/SparkSubmitCommandBuilderSuite.java
@@ -79,7 +79,6 @@ public class SparkSubmitCommandBuilderSuite extends BaseSuite {
assertTrue(findInStringList(env.get(CommandBuilderUtils.getLibPathEnvName()),
File.pathSeparator, "/driverLibPath"));
assertTrue(findInStringList(findArgValue(cmd, "-cp"), File.pathSeparator, "/driverCp"));
- assertTrue("Driver -Xms should be configured.", cmd.contains("-Xms42g"));
assertTrue("Driver -Xmx should be configured.", cmd.contains("-Xmx42g"));
assertTrue("Command should contain user-defined conf.",
Collections.indexOfSubList(cmd, Arrays.asList(parser.CONF, "spark.randomOption=foo")) > 0);
@@ -160,7 +159,7 @@ public class SparkSubmitCommandBuilderSuite extends BaseSuite {
"SparkPi",
"42");
- Map<String, String> env = new HashMap<String, String>();
+ Map<String, String> env = new HashMap<>();
List<String> cmd = buildCommand(sparkSubmitArgs, env);
assertEquals("foo", findArgValue(cmd, parser.MASTER));
assertEquals("bar", findArgValue(cmd, parser.DEPLOY_MODE));
@@ -202,12 +201,11 @@ public class SparkSubmitCommandBuilderSuite extends BaseSuite {
// Checks below are different for driver and non-driver mode.
if (isDriver) {
- assertTrue("Driver -Xms should be configured.", cmd.contains("-Xms1g"));
assertTrue("Driver -Xmx should be configured.", cmd.contains("-Xmx1g"));
} else {
boolean found = false;
for (String arg : cmd) {
- if (arg.startsWith("-Xms") || arg.startsWith("-Xmx")) {
+ if (arg.startsWith("-Xmx")) {
found = true;
break;
}
diff --git a/launcher/src/test/resources/log4j.properties b/launcher/src/test/resources/log4j.properties
index c64b1565e1..744c456cb2 100644
--- a/launcher/src/test/resources/log4j.properties
+++ b/launcher/src/test/resources/log4j.properties
@@ -30,5 +30,4 @@ log4j.appender.childproc.layout=org.apache.log4j.PatternLayout
log4j.appender.childproc.layout.ConversionPattern=%t: %m%n
# Ignore messages below warning level from Jetty, because it's a bit verbose
-log4j.logger.org.spark-project.jetty=WARN
-org.spark-project.jetty.LEVEL=WARN
+log4j.logger.org.spark_project.jetty=WARN
diff --git a/mllib-local/pom.xml b/mllib-local/pom.xml
new file mode 100644
index 0000000000..68f15dd905
--- /dev/null
+++ b/mllib-local/pom.xml
@@ -0,0 +1,74 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<!--
+ ~ Licensed to the Apache Software Foundation (ASF) under one or more
+ ~ contributor license agreements. See the NOTICE file distributed with
+ ~ this work for additional information regarding copyright ownership.
+ ~ The ASF licenses this file to You under the Apache License, Version 2.0
+ ~ (the "License"); you may not use this file except in compliance with
+ ~ the License. You may obtain a copy of the License at
+ ~
+ ~ http://www.apache.org/licenses/LICENSE-2.0
+ ~
+ ~ Unless required by applicable law or agreed to in writing, software
+ ~ distributed under the License is distributed on an "AS IS" BASIS,
+ ~ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ ~ See the License for the specific language governing permissions and
+ ~ limitations under the License.
+ -->
+
+<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
+ <modelVersion>4.0.0</modelVersion>
+ <parent>
+ <groupId>org.apache.spark</groupId>
+ <artifactId>spark-parent_2.11</artifactId>
+ <version>2.0.0-SNAPSHOT</version>
+ <relativePath>../pom.xml</relativePath>
+ </parent>
+
+ <groupId>org.apache.spark</groupId>
+ <artifactId>spark-mllib-local_2.11</artifactId>
+ <properties>
+ <sbt.project.name>mllib-local</sbt.project.name>
+ </properties>
+ <packaging>jar</packaging>
+ <name>Spark Project ML Local Library</name>
+ <url>http://spark.apache.org/</url>
+
+ <dependencies>
+ <dependency>
+ <groupId>org.scalanlp</groupId>
+ <artifactId>breeze_${scala.binary.version}</artifactId>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.commons</groupId>
+ <artifactId>commons-math3</artifactId>
+ </dependency>
+ <dependency>
+ <groupId>org.scalacheck</groupId>
+ <artifactId>scalacheck_${scala.binary.version}</artifactId>
+ <scope>test</scope>
+ </dependency>
+ <dependency>
+ <groupId>org.mockito</groupId>
+ <artifactId>mockito-core</artifactId>
+ <scope>test</scope>
+ </dependency>
+ </dependencies>
+ <profiles>
+ <profile>
+ <id>netlib-lgpl</id>
+ <dependencies>
+ <dependency>
+ <groupId>com.github.fommil.netlib</groupId>
+ <artifactId>all</artifactId>
+ <version>${netlib.java.version}</version>
+ <type>pom</type>
+ </dependency>
+ </dependencies>
+ </profile>
+ </profiles>
+ <build>
+ <outputDirectory>target/scala-${scala.binary.version}/classes</outputDirectory>
+ <testOutputDirectory>target/scala-${scala.binary.version}/test-classes</testOutputDirectory>
+ </build>
+</project>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserConf.scala b/mllib-local/src/main/scala/org/apache/spark/ml/DummyTesting.scala
index ce449b1143..6b3268cdfa 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserConf.scala
+++ b/mllib-local/src/main/scala/org/apache/spark/ml/DummyTesting.scala
@@ -14,13 +14,10 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-package org.apache.spark.sql.catalyst.parser
-trait ParserConf {
- def supportQuotedId: Boolean
- def supportSQL11ReservedKeywords: Boolean
-}
+package org.apache.spark.ml
-case class SimpleParserConf(
- supportQuotedId: Boolean = true,
- supportSQL11ReservedKeywords: Boolean = false) extends ParserConf
+// This is a private object for testing whether the new build works. To be removed soon.
+private[ml] object DummyTesting {
+ private[ml] def add10(input: Double): Double = input + 10
+}
diff --git a/mllib-local/src/test/scala/org/apache/spark/ml/DummyTestingSuite.scala b/mllib-local/src/test/scala/org/apache/spark/ml/DummyTestingSuite.scala
new file mode 100644
index 0000000000..51b7c2409f
--- /dev/null
+++ b/mllib-local/src/test/scala/org/apache/spark/ml/DummyTestingSuite.scala
@@ -0,0 +1,28 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml
+
+import org.scalatest.FunSuite // scalastyle:ignore funsuite
+
+// This suite tests whether the new build works. To be removed soon.
+class DummyTestingSuite extends FunSuite { // scalastyle:ignore funsuite
+
+ test("This is testing if the new build works.") {
+ assert(DummyTesting.add10(15) === 25)
+ }
+}
diff --git a/mllib/pom.xml b/mllib/pom.xml
index 428176dcbf..24d8274e22 100644
--- a/mllib/pom.xml
+++ b/mllib/pom.xml
@@ -63,21 +63,20 @@
<version>${project.version}</version>
</dependency>
<dependency>
+ <groupId>org.apache.spark</groupId>
+ <artifactId>spark-mllib-local_${scala.binary.version}</artifactId>
+ <version>${project.version}</version>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.spark</groupId>
+ <artifactId>spark-mllib-local_${scala.binary.version}</artifactId>
+ <version>${project.version}</version>
+ <type>test-jar</type>
+ <scope>test</scope>
+ </dependency>
+ <dependency>
<groupId>org.scalanlp</groupId>
<artifactId>breeze_${scala.binary.version}</artifactId>
- <version>0.11.2</version>
- <exclusions>
- <!-- This is included as a compile-scoped dependency by jtransforms, which is
- a dependency of breeze. -->
- <exclusion>
- <groupId>junit</groupId>
- <artifactId>junit</artifactId>
- </exclusion>
- <exclusion>
- <groupId>org.apache.commons</groupId>
- <artifactId>commons-math3</artifactId>
- </exclusion>
- </exclusions>
</dependency>
<dependency>
<groupId>org.apache.commons</groupId>
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala b/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala
index 57e416591d..1247882d6c 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala
@@ -19,9 +19,9 @@ package org.apache.spark.ml
import scala.annotation.varargs
-import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.annotation.{DeveloperApi, Since}
import org.apache.spark.ml.param.{ParamMap, ParamPair}
-import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.Dataset
/**
* :: DeveloperApi ::
@@ -39,8 +39,9 @@ abstract class Estimator[M <: Model[M]] extends PipelineStage {
* Estimator's embedded ParamMap.
* @return fitted model
*/
+ @Since("2.0.0")
@varargs
- def fit(dataset: DataFrame, firstParamPair: ParamPair[_], otherParamPairs: ParamPair[_]*): M = {
+ def fit(dataset: Dataset[_], firstParamPair: ParamPair[_], otherParamPairs: ParamPair[_]*): M = {
val map = new ParamMap()
.put(firstParamPair)
.put(otherParamPairs: _*)
@@ -55,14 +56,16 @@ abstract class Estimator[M <: Model[M]] extends PipelineStage {
* These values override any specified in this Estimator's embedded ParamMap.
* @return fitted model
*/
- def fit(dataset: DataFrame, paramMap: ParamMap): M = {
+ @Since("2.0.0")
+ def fit(dataset: Dataset[_], paramMap: ParamMap): M = {
copy(paramMap).fit(dataset)
}
/**
* Fits a model to the input data.
*/
- def fit(dataset: DataFrame): M
+ @Since("2.0.0")
+ def fit(dataset: Dataset[_]): M
/**
* Fits multiple models to the input data with multiple sets of parameters.
@@ -74,7 +77,8 @@ abstract class Estimator[M <: Model[M]] extends PipelineStage {
* These values override any specified in this Estimator's embedded ParamMap.
* @return fitted models, matching the input parameter maps
*/
- def fit(dataset: DataFrame, paramMaps: Array[ParamMap]): Seq[M] = {
+ @Since("2.0.0")
+ def fit(dataset: Dataset[_], paramMaps: Array[ParamMap]): Seq[M] = {
paramMaps.map(fit(dataset, _))
}
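
The signature change above widens fit from DataFrame to Dataset[_], so callers can pass either an untyped DataFrame or a strongly typed Dataset without converting first. A Spark-free sketch of why the wider parameter type helps; Dataset, DataFrame and MyPoint below are stand-ins, not the real Spark classes:

// Stand-in types only; this illustrates the variance of the new fit() signature.
object FitSignatureSketch {
  class Dataset[T](val rows: Seq[T])
  type Row = Map[String, Any]
  type DataFrame = Dataset[Row]

  case class MyPoint(label: Double, feature: Double)

  // Before: fit(dataset: DataFrame) accepted only untyped data.
  // After: fit(dataset: Dataset[_]) accepts any element type.
  def fit(dataset: Dataset[_]): String = s"fit on ${dataset.rows.size} rows"

  def main(args: Array[String]): Unit = {
    val df: DataFrame = new Dataset[Row](Seq(Map("label" -> 1.0)))
    val ds: Dataset[MyPoint] = new Dataset(Seq(MyPoint(1.0, 0.5)))
    println(fit(df))  // an untyped DataFrame still works
    println(fit(ds))  // a typed Dataset now works without calling toDF first
  }
}
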
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
index 3a99979a88..82066726a0 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
@@ -31,7 +31,7 @@ import org.apache.spark.annotation.{DeveloperApi, Experimental, Since}
import org.apache.spark.internal.Logging
import org.apache.spark.ml.param.{Param, ParamMap, Params}
import org.apache.spark.ml.util._
-import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.types.StructType
/**
@@ -123,8 +123,8 @@ class Pipeline @Since("1.4.0") (
* @param dataset input dataset
* @return fitted pipeline
*/
- @Since("1.2.0")
- override def fit(dataset: DataFrame): PipelineModel = {
+ @Since("2.0.0")
+ override def fit(dataset: Dataset[_]): PipelineModel = {
transformSchema(dataset.schema, logging = true)
val theStages = $(stages)
// Search for the last estimator.
@@ -147,7 +147,7 @@ class Pipeline @Since("1.4.0") (
t
case _ =>
throw new IllegalArgumentException(
- s"Do not support stage $stage of type ${stage.getClass}")
+ s"Does not support stage $stage of type ${stage.getClass}")
}
if (index < indexOfLastEstimator) {
curDataset = transformer.transform(curDataset)
@@ -291,10 +291,10 @@ class PipelineModel private[ml] (
this(uid, stages.asScala.toArray)
}
- @Since("1.2.0")
- override def transform(dataset: DataFrame): DataFrame = {
+ @Since("2.0.0")
+ override def transform(dataset: Dataset[_]): DataFrame = {
transformSchema(dataset.schema, logging = true)
- stages.foldLeft(dataset)((cur, transformer) => transformer.transform(cur))
+ stages.foldLeft(dataset.toDF)((cur, transformer) => transformer.transform(cur))
}
@Since("1.2.0")
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
index ebe48700f8..81140d1f7b 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
@@ -24,7 +24,7 @@ import org.apache.spark.ml.util.SchemaUtils
import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.{DataFrame, Dataset, Row}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{DataType, DoubleType, StructType}
@@ -36,6 +36,7 @@ private[ml] trait PredictorParams extends Params
/**
* Validates and transforms the input schema with the provided param map.
+ *
* @param schema input schema
* @param fitting whether this is in fitting
* @param featuresDataType SQL DataType for FeaturesType.
@@ -49,8 +50,7 @@ private[ml] trait PredictorParams extends Params
// TODO: Support casting Array[Double] and Array[Float] to Vector when FeaturesType = Vector
SchemaUtils.checkColumnType(schema, $(featuresCol), featuresDataType)
if (fitting) {
- // TODO: Allow other numeric types
- SchemaUtils.checkColumnType(schema, $(labelCol), DoubleType)
+ SchemaUtils.checkNumericType(schema, $(labelCol))
}
SchemaUtils.appendColumn(schema, $(predictionCol), DoubleType)
}
@@ -83,7 +83,7 @@ abstract class Predictor[
/** @group setParam */
def setPredictionCol(value: String): Learner = set(predictionCol, value).asInstanceOf[Learner]
- override def fit(dataset: DataFrame): M = {
+ override def fit(dataset: Dataset[_]): M = {
// This handles a few items such as schema validation.
// Developers only need to implement train().
transformSchema(dataset.schema, logging = true)
@@ -100,7 +100,7 @@ abstract class Predictor[
* @param dataset Training dataset
* @return Fitted model
*/
- protected def train(dataset: DataFrame): M
+ protected def train(dataset: Dataset[_]): M
/**
* Returns the SQL DataType corresponding to the FeaturesType type parameter.
@@ -120,10 +120,9 @@ abstract class Predictor[
* Extract [[labelCol]] and [[featuresCol]] from the given dataset,
* and put it in an RDD with strong types.
*/
- protected def extractLabeledPoints(dataset: DataFrame): RDD[LabeledPoint] = {
- dataset.select($(labelCol), $(featuresCol)).rdd.map {
- case Row(label: Double, features: Vector) =>
- LabeledPoint(label, features)
+ protected def extractLabeledPoints(dataset: Dataset[_]): RDD[LabeledPoint] = {
+ dataset.select(col($(labelCol)).cast(DoubleType), col($(featuresCol))).rdd.map {
+ case Row(label: Double, features: Vector) => LabeledPoint(label, features)
}
}
}
@@ -172,18 +171,18 @@ abstract class PredictionModel[FeaturesType, M <: PredictionModel[FeaturesType,
* @param dataset input dataset
* @return transformed dataset with [[predictionCol]] of type [[Double]]
*/
- override def transform(dataset: DataFrame): DataFrame = {
+ override def transform(dataset: Dataset[_]): DataFrame = {
transformSchema(dataset.schema, logging = true)
if ($(predictionCol).nonEmpty) {
transformImpl(dataset)
} else {
this.logWarning(s"$uid: Predictor.transform() was called as NOOP" +
" since no output columns were set.")
- dataset
+ dataset.toDF
}
}
- protected def transformImpl(dataset: DataFrame): DataFrame = {
+ protected def transformImpl(dataset: Dataset[_]): DataFrame = {
val predictUDF = udf { (features: Any) =>
predict(features.asInstanceOf[FeaturesType])
}
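
Because the label column is now only required to be numeric (checkNumericType) and is cast to double before extraction, integer or float labels no longer break the Row(label: Double, ...) match. A Spark-free sketch of the effect of that cast, using a hypothetical toDoubleLabel helper:

// Sketch only: normalizing any numeric label to Double, which is what
// casting the label column to DoubleType achieves before pattern matching.
object LabelCastSketch {
  def toDoubleLabel(label: Any): Double = label match {
    case d: Double => d
    case f: Float  => f.toDouble
    case l: Long   => l.toDouble
    case i: Int    => i.toDouble
    case s: Short  => s.toDouble
    case b: Byte   => b.toDouble
    case other     => throw new IllegalArgumentException(s"Non-numeric label: $other")
  }

  def main(args: Array[String]): Unit = {
    val rawLabels: Seq[Any] = Seq(1, 0L, 1.0f, 0.0)   // mixed numeric label types
    println(rawLabels.map(toDoubleLabel))             // List(1.0, 0.0, 1.0, 0.0)
  }
}
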
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala
index 2538c0f477..a3a2b55adc 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala
@@ -19,11 +19,11 @@ package org.apache.spark.ml
import scala.annotation.varargs
-import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.annotation.{DeveloperApi, Since}
import org.apache.spark.internal.Logging
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
-import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
@@ -41,9 +41,10 @@ abstract class Transformer extends PipelineStage {
* @param otherParamPairs other param pairs, overwrite embedded params
* @return transformed dataset
*/
+ @Since("2.0.0")
@varargs
def transform(
- dataset: DataFrame,
+ dataset: Dataset[_],
firstParamPair: ParamPair[_],
otherParamPairs: ParamPair[_]*): DataFrame = {
val map = new ParamMap()
@@ -58,14 +59,16 @@ abstract class Transformer extends PipelineStage {
* @param paramMap additional parameters, overwrite embedded params
* @return transformed dataset
*/
- def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = {
+ @Since("2.0.0")
+ def transform(dataset: Dataset[_], paramMap: ParamMap): DataFrame = {
this.copy(paramMap).transform(dataset)
}
/**
* Transforms the input dataset.
*/
- def transform(dataset: DataFrame): DataFrame
+ @Since("2.0.0")
+ def transform(dataset: Dataset[_]): DataFrame
override def copy(extra: ParamMap): Transformer
}
@@ -113,7 +116,7 @@ abstract class UnaryTransformer[IN, OUT, T <: UnaryTransformer[IN, OUT, T]]
StructType(outputFields)
}
- override def transform(dataset: DataFrame): DataFrame = {
+ override def transform(dataset: Dataset[_]): DataFrame = {
transformSchema(dataset.schema, logging = true)
val transformUDF = udf(this.createTransformFunc, outputDataType)
dataset.withColumn($(outputCol), transformUDF(dataset($(inputCol))))
diff --git a/mllib/src/main/scala/org/apache/spark/ml/ann/Layer.scala b/mllib/src/main/scala/org/apache/spark/ml/ann/Layer.scala
index 2cd94fa8f5..a5b84116e6 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/ann/Layer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/ann/Layer.scala
@@ -17,9 +17,9 @@
package org.apache.spark.ml.ann
-import breeze.linalg.{*, axpy => Baxpy, sum => Bsum, DenseMatrix => BDM, DenseVector => BDV,
- Vector => BV}
-import breeze.numerics.{log => Blog, sigmoid => Bsigmoid}
+import java.util.Random
+
+import breeze.linalg.{*, axpy => Baxpy, DenseMatrix => BDM, DenseVector => BDV, Vector => BV}
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.optimization._
@@ -32,20 +32,46 @@ import org.apache.spark.util.random.XORShiftRandom
*
*/
private[ann] trait Layer extends Serializable {
+
/**
- * Returns the instance of the layer based on weights provided
- * @param weights vector with layer weights
- * @param position position of weights in the vector
- * @return the layer model
+ * Number of weights; used to allocate memory for the weights vector.
+ */
+ val weightSize: Int
+
+ /**
+ * Returns the output size given the input size (not counting the stack size).
+ * Output size is used to allocate memory for the output.
+ *
+ * @param inputSize input size
+ * @return output size
*/
- def getInstance(weights: Vector, position: Int): LayerModel
+ def getOutputSize(inputSize: Int): Int
/**
+ * If true, memory is not allocated for the output of this layer.
+ * The memory allocated to the previous layer is used to write the output of this layer.
+ * The developer can set this to true if computing the delta of the previous layer
+ * does not involve its output, so the current layer can write there.
+ * This also means that both layers have the same number of outputs.
+ */
+ val inPlace: Boolean
+
+ /**
+ * Returns the instance of the layer based on weights provided.
+ * The size of initialWeights must be equal to weightSize.
+ *
+ * @param initialWeights vector with layer weights
+ * @return the layer model
+ */
+ def createModel(initialWeights: BDV[Double]): LayerModel
+ /**
* Returns the instance of the layer with random generated weights
- * @param seed seed
+ *
+ * @param weights vector for weights initialization; its size must be equal to weightSize
+ * @param random random number generator
* @return the layer model
*/
- def getInstance(seed: Long): LayerModel
+ def initModel(weights: BDV[Double], random: Random): LayerModel
}
/**
@@ -54,92 +80,102 @@ private[ann] trait Layer extends Serializable {
* Can return weights in Vector format.
*/
private[ann] trait LayerModel extends Serializable {
- /**
- * number of weights
- */
- val size: Int
+ val weights: BDV[Double]
/**
* Evaluates the data (process the data through the layer)
+ * Output is allocated based on the size provided by the
+ * LayerModel implementation and the stack (batch) size.
+ * The developer is responsible for checking the size of output
+ * when writing to it.
+ *
* @param data data
- * @return processed data
+ * @param output output (modified in place)
*/
- def eval(data: BDM[Double]): BDM[Double]
+ def eval(data: BDM[Double], output: BDM[Double]): Unit
/**
* Computes the delta for back propagation
- * @param nextDelta delta of the next layer
- * @param input input data
- * @return delta
+ * Delta is allocated based on the size provided by the
+ * LayerModel implementation and the stack (batch) size.
+ * The developer is responsible for checking the size of
+ * prevDelta when writing to it.
+ *
+ * @param delta delta of this layer
+ * @param output output of this layer
+ * @param prevDelta the previous delta (modified in place)
*/
- def prevDelta(nextDelta: BDM[Double], input: BDM[Double]): BDM[Double]
+ def computePrevDelta(delta: BDM[Double], output: BDM[Double], prevDelta: BDM[Double]): Unit
/**
* Computes the gradient
+ * cumGrad is a wrapper over this layer's part of the weight vector;
+ * its size is based on the weightSize provided by
+ * the LayerModel implementation.
+ *
* @param delta delta for this layer
* @param input input data
- * @return gradient
+ * @param cumGrad cumulative gradient (modified in place)
*/
- def grad(delta: BDM[Double], input: BDM[Double]): Array[Double]
-
- /**
- * Returns weights for the layer in a single vector
- * @return layer weights
- */
- def weights(): Vector
+ def grad(delta: BDM[Double], input: BDM[Double], cumGrad: BDV[Double]): Unit
}
/**
* Layer properties of affine transformations, that is y=A*x+b
+ *
* @param numIn number of inputs
* @param numOut number of outputs
*/
private[ann] class AffineLayer(val numIn: Int, val numOut: Int) extends Layer {
- override def getInstance(weights: Vector, position: Int): LayerModel = {
- AffineLayerModel(this, weights, position)
- }
+ override val weightSize = numIn * numOut + numOut
- override def getInstance(seed: Long = 11L): LayerModel = {
- AffineLayerModel(this, seed)
- }
+ override def getOutputSize(inputSize: Int): Int = numOut
+
+ override val inPlace = false
+
+ override def createModel(weights: BDV[Double]): LayerModel = new AffineLayerModel(weights, this)
+
+ override def initModel(weights: BDV[Double], random: Random): LayerModel =
+ AffineLayerModel(this, weights, random)
}
/**
- * Model of Affine layer y=A*x+b
- * @param w weights (matrix A)
- * @param b bias (vector b)
+ * Model of Affine layer
+ *
+ * @param weights weights
+ * @param layer layer properties
*/
-private[ann] class AffineLayerModel private(w: BDM[Double], b: BDV[Double]) extends LayerModel {
- val size = w.size + b.length
- val gwb = new Array[Double](size)
- private lazy val gw: BDM[Double] = new BDM[Double](w.rows, w.cols, gwb)
- private lazy val gb: BDV[Double] = new BDV[Double](gwb, w.size)
- private var z: BDM[Double] = null
- private var d: BDM[Double] = null
+private[ann] class AffineLayerModel private[ann] (
+ val weights: BDV[Double],
+ val layer: AffineLayer) extends LayerModel {
+ val w = new BDM[Double](layer.numOut, layer.numIn, weights.data, weights.offset)
+ val b =
+ new BDV[Double](weights.data, weights.offset + (layer.numOut * layer.numIn), 1, layer.numOut)
+
private var ones: BDV[Double] = null
- override def eval(data: BDM[Double]): BDM[Double] = {
- if (z == null || z.cols != data.cols) z = new BDM[Double](w.rows, data.cols)
- z(::, *) := b
- BreezeUtil.dgemm(1.0, w, data, 1.0, z)
- z
+ override def eval(data: BDM[Double], output: BDM[Double]): Unit = {
+ output(::, *) := b
+ BreezeUtil.dgemm(1.0, w, data, 1.0, output)
}
- override def prevDelta(nextDelta: BDM[Double], input: BDM[Double]): BDM[Double] = {
- if (d == null || d.cols != nextDelta.cols) d = new BDM[Double](w.cols, nextDelta.cols)
- BreezeUtil.dgemm(1.0, w.t, nextDelta, 0.0, d)
- d
+ override def computePrevDelta(
+ delta: BDM[Double],
+ output: BDM[Double],
+ prevDelta: BDM[Double]): Unit = {
+ BreezeUtil.dgemm(1.0, w.t, delta, 0.0, prevDelta)
}
- override def grad(delta: BDM[Double], input: BDM[Double]): Array[Double] = {
- BreezeUtil.dgemm(1.0 / input.cols, delta, input.t, 0.0, gw)
+ override def grad(delta: BDM[Double], input: BDM[Double], cumGrad: BDV[Double]): Unit = {
+ // compute gradient of weights
+ val cumGradientOfWeights = new BDM[Double](w.rows, w.cols, cumGrad.data, cumGrad.offset)
+ BreezeUtil.dgemm(1.0 / input.cols, delta, input.t, 1.0, cumGradientOfWeights)
if (ones == null || ones.length != delta.cols) ones = BDV.ones[Double](delta.cols)
- BreezeUtil.dgemv(1.0 / input.cols, delta, ones, 0.0, gb)
- gwb
+ // compute gradient of bias
+ val cumGradientOfBias = new BDV[Double](cumGrad.data, cumGrad.offset + w.size, 1, b.length)
+ BreezeUtil.dgemv(1.0 / input.cols, delta, ones, 1.0, cumGradientOfBias)
}
-
- override def weights(): Vector = AffineLayerModel.roll(w, b)
}
/**
@@ -149,73 +185,40 @@ private[ann] object AffineLayerModel {
/**
* Creates a model of Affine layer
+ *
* @param layer layer properties
- * @param weights vector with weights
- * @param position position of weights in the vector
- * @return model of Affine layer
- */
- def apply(layer: AffineLayer, weights: Vector, position: Int): AffineLayerModel = {
- val (w, b) = unroll(weights, position, layer.numIn, layer.numOut)
- new AffineLayerModel(w, b)
- }
-
- /**
- * Creates a model of Affine layer
- * @param layer layer properties
- * @param seed seed
+ * @param weights vector for weights initialization
+ * @param random random number generator
* @return model of Affine layer
*/
- def apply(layer: AffineLayer, seed: Long): AffineLayerModel = {
- val (w, b) = randomWeights(layer.numIn, layer.numOut, seed)
- new AffineLayerModel(w, b)
- }
-
- /**
- * Unrolls the weights from the vector
- * @param weights vector with weights
- * @param position position of weights for this layer
- * @param numIn number of layer inputs
- * @param numOut number of layer outputs
- * @return matrix A and vector b
- */
- def unroll(
- weights: Vector,
- position: Int,
- numIn: Int,
- numOut: Int): (BDM[Double], BDV[Double]) = {
- val weightsCopy = weights.toArray
- // TODO: the array is not copied to BDMs, make sure this is OK!
- val a = new BDM[Double](numOut, numIn, weightsCopy, position)
- val b = new BDV[Double](weightsCopy, position + (numOut * numIn), 1, numOut)
- (a, b)
- }
-
- /**
- * Roll the layer weights into a vector
- * @param a matrix A
- * @param b vector b
- * @return vector of weights
- */
- def roll(a: BDM[Double], b: BDV[Double]): Vector = {
- val result = new Array[Double](a.size + b.length)
- // TODO: make sure that we need to copy!
- System.arraycopy(a.toArray, 0, result, 0, a.size)
- System.arraycopy(b.toArray, 0, result, a.size, b.length)
- Vectors.dense(result)
+ def apply(layer: AffineLayer, weights: BDV[Double], random: Random): AffineLayerModel = {
+ randomWeights(layer.numIn, layer.numOut, weights, random)
+ new AffineLayerModel(weights, layer)
}
/**
- * Generate random weights for the layer
- * @param numIn number of inputs
+ * Initializes weights randomly in the interval [-a/sqrt(in), a/sqrt(in)],
+ * using the [Bottou-88] heuristic, where a is chosen in such a way that the
+ * weight variance corresponds to the point of maximal curvature of the
+ * activation function (which is approximately 2.38 for a standard sigmoid).
+ *
+ * @param numIn number of inputs
* @param numOut number of outputs
- * @param seed seed
- * @return (matrix A, vector b)
+ * @param weights vector for weights initialization
+ * @param random random number generator
*/
- def randomWeights(numIn: Int, numOut: Int, seed: Long = 11L): (BDM[Double], BDV[Double]) = {
- val rand: XORShiftRandom = new XORShiftRandom(seed)
- val weights = BDM.fill[Double](numOut, numIn) { (rand.nextDouble * 4.8 - 2.4) / numIn }
- val bias = BDV.fill[Double](numOut) { (rand.nextDouble * 4.8 - 2.4) / numIn }
- (weights, bias)
+ def randomWeights(
+ numIn: Int,
+ numOut: Int,
+ weights: BDV[Double],
+ random: Random): Unit = {
+ var i = 0
+ val sqrtIn = math.sqrt(numIn)
+ while (i < weights.length) {
+ weights(i) = (random.nextDouble * 4.8 - 2.4) / sqrtIn
+ i += 1
+ }
}
}
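
The initializer above draws each weight uniformly from roughly [-2.4, 2.4] and scales it by 1/sqrt(numIn), following the Bottou heuristic described in the comment. A minimal standalone sketch of that initialization (the flat matrix-then-bias layout is illustrative):

import java.util.Random

// Sketch of the [-a/sqrt(in), a/sqrt(in)] initialization with a = 2.4,
// matching the (random.nextDouble * 4.8 - 2.4) / sqrtIn expression above.
object WeightInitSketch {
  def randomWeights(numIn: Int, numOut: Int, seed: Long): Array[Double] = {
    val random = new Random(seed)
    val sqrtIn = math.sqrt(numIn)
    // numIn * numOut matrix weights followed by numOut bias weights
    Array.fill(numIn * numOut + numOut)((random.nextDouble() * 4.8 - 2.4) / sqrtIn)
  }

  def main(args: Array[String]): Unit = {
    val w = randomWeights(numIn = 4, numOut = 3, seed = 11L)
    println(w.length)            // 15
    println(w.max <= 2.4 / 2.0)  // true: bounded by 2.4 / sqrt(4)
  }
}
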
@@ -226,44 +229,21 @@ private[ann] trait ActivationFunction extends Serializable {
/**
* Implements a function
- * @param x input data
- * @param y output data
*/
- def eval(x: BDM[Double], y: BDM[Double]): Unit
+ def eval: Double => Double
/**
* Implements a derivative of a function (needed for the back propagation)
- * @param x input data
- * @param y output data
*/
- def derivative(x: BDM[Double], y: BDM[Double]): Unit
-
- /**
- * Implements a cross entropy error of a function.
- * Needed if the functional layer that contains this function is the output layer
- * of the network.
- * @param target target output
- * @param output computed output
- * @param result intermediate result
- * @return cross-entropy
- */
- def crossEntropy(target: BDM[Double], output: BDM[Double], result: BDM[Double]): Double
-
- /**
- * Implements a mean squared error of a function
- * @param target target output
- * @param output computed output
- * @param result intermediate result
- * @return mean squared error
- */
- def squared(target: BDM[Double], output: BDM[Double], result: BDM[Double]): Double
+ def derivative: Double => Double
}
/**
- * Implements in-place application of functions
+ * Implements in-place application of functions to the arrays
*/
-private[ann] object ActivationFunction {
+private[ann] object ApplyInPlace {
+ // TODO: use Breeze UFunc
def apply(x: BDM[Double], y: BDM[Double], func: Double => Double): Unit = {
var i = 0
while (i < x.rows) {
@@ -276,6 +256,7 @@ private[ann] object ActivationFunction {
}
}
+ // TODO: use Breeze UFunc
def apply(
x1: BDM[Double],
x2: BDM[Double],
@@ -294,179 +275,86 @@ private[ann] object ActivationFunction {
}
/**
- * Implements SoftMax activation function
- */
-private[ann] class SoftmaxFunction extends ActivationFunction {
- override def eval(x: BDM[Double], y: BDM[Double]): Unit = {
- var j = 0
- // find max value to make sure later that exponent is computable
- while (j < x.cols) {
- var i = 0
- var max = Double.MinValue
- while (i < x.rows) {
- if (x(i, j) > max) {
- max = x(i, j)
- }
- i += 1
- }
- var sum = 0.0
- i = 0
- while (i < x.rows) {
- val res = Math.exp(x(i, j) - max)
- y(i, j) = res
- sum += res
- i += 1
- }
- i = 0
- while (i < x.rows) {
- y(i, j) /= sum
- i += 1
- }
- j += 1
- }
- }
-
- override def crossEntropy(
- output: BDM[Double],
- target: BDM[Double],
- result: BDM[Double]): Double = {
- def m(o: Double, t: Double): Double = o - t
- ActivationFunction(output, target, result, m)
- -Bsum( target :* Blog(output)) / output.cols
- }
-
- override def derivative(x: BDM[Double], y: BDM[Double]): Unit = {
- def sd(z: Double): Double = (1 - z) * z
- ActivationFunction(x, y, sd)
- }
-
- override def squared(output: BDM[Double], target: BDM[Double], result: BDM[Double]): Double = {
- throw new UnsupportedOperationException("Sorry, squared error is not defined for SoftMax.")
- }
-}
-
-/**
* Implements Sigmoid activation function
*/
private[ann] class SigmoidFunction extends ActivationFunction {
- override def eval(x: BDM[Double], y: BDM[Double]): Unit = {
- def s(z: Double): Double = Bsigmoid(z)
- ActivationFunction(x, y, s)
- }
-
- override def crossEntropy(
- output: BDM[Double],
- target: BDM[Double],
- result: BDM[Double]): Double = {
- def m(o: Double, t: Double): Double = o - t
- ActivationFunction(output, target, result, m)
- -Bsum(target :* Blog(output)) / output.cols
- }
- override def derivative(x: BDM[Double], y: BDM[Double]): Unit = {
- def sd(z: Double): Double = (1 - z) * z
- ActivationFunction(x, y, sd)
- }
+ override def eval: (Double) => Double = x => 1.0 / (1 + math.exp(-x))
- override def squared(output: BDM[Double], target: BDM[Double], result: BDM[Double]): Double = {
- // TODO: make it readable
- def m(o: Double, t: Double): Double = (o - t)
- ActivationFunction(output, target, result, m)
- val e = Bsum(result :* result) / 2 / output.cols
- def m2(x: Double, o: Double) = x * (o - o * o)
- ActivationFunction(result, output, result, m2)
- e
- }
+ override def derivative: (Double) => Double = z => (1 - z) * z
}
/**
* Functional layer properties, y = f(x)
+ *
* @param activationFunction activation function
*/
private[ann] class FunctionalLayer (val activationFunction: ActivationFunction) extends Layer {
- override def getInstance(weights: Vector, position: Int): LayerModel = getInstance(0L)
- override def getInstance(seed: Long): LayerModel =
- FunctionalLayerModel(this)
+ override val weightSize = 0
+
+ override def getOutputSize(inputSize: Int): Int = inputSize
+
+ override val inPlace = true
+
+ override def createModel(weights: BDV[Double]): LayerModel = new FunctionalLayerModel(this)
+
+ override def initModel(weights: BDV[Double], random: Random): LayerModel =
+ createModel(weights)
}
/**
* Functional layer model. Holds no weights.
- * @param activationFunction activation function
+ *
+ * @param layer functional layer
*/
-private[ann] class FunctionalLayerModel private (val activationFunction: ActivationFunction)
+private[ann] class FunctionalLayerModel private[ann] (val layer: FunctionalLayer)
extends LayerModel {
- val size = 0
- // matrices for in-place computations
- // outputs
- private var f: BDM[Double] = null
- // delta
- private var d: BDM[Double] = null
- // matrix for error computation
- private var e: BDM[Double] = null
- // delta gradient
- private lazy val dg = new Array[Double](0)
- override def eval(data: BDM[Double]): BDM[Double] = {
- if (f == null || f.cols != data.cols) f = new BDM[Double](data.rows, data.cols)
- activationFunction.eval(data, f)
- f
- }
+ // empty weights
+ val weights = new BDV[Double](0)
- override def prevDelta(nextDelta: BDM[Double], input: BDM[Double]): BDM[Double] = {
- if (d == null || d.cols != nextDelta.cols) d = new BDM[Double](nextDelta.rows, nextDelta.cols)
- activationFunction.derivative(input, d)
- d :*= nextDelta
- d
+ override def eval(data: BDM[Double], output: BDM[Double]): Unit = {
+ ApplyInPlace(data, output, layer.activationFunction.eval)
}
- override def grad(delta: BDM[Double], input: BDM[Double]): Array[Double] = dg
-
- override def weights(): Vector = Vectors.dense(new Array[Double](0))
-
- def crossEntropy(output: BDM[Double], target: BDM[Double]): (BDM[Double], Double) = {
- if (e == null || e.cols != output.cols) e = new BDM[Double](output.rows, output.cols)
- val error = activationFunction.crossEntropy(output, target, e)
- (e, error)
+ override def computePrevDelta(
+ nextDelta: BDM[Double],
+ input: BDM[Double],
+ delta: BDM[Double]): Unit = {
+ ApplyInPlace(input, delta, layer.activationFunction.derivative)
+ delta :*= nextDelta
}
- def squared(output: BDM[Double], target: BDM[Double]): (BDM[Double], Double) = {
- if (e == null || e.cols != output.cols) e = new BDM[Double](output.rows, output.cols)
- val error = activationFunction.squared(output, target, e)
- (e, error)
- }
-
- def error(output: BDM[Double], target: BDM[Double]): (BDM[Double], Double) = {
- // TODO: allow user pick error
- activationFunction match {
- case sigmoid: SigmoidFunction => squared(output, target)
- case softmax: SoftmaxFunction => crossEntropy(output, target)
- }
- }
-}
-
-/**
- * Fabric of functional layer models
- */
-private[ann] object FunctionalLayerModel {
- def apply(layer: FunctionalLayer): FunctionalLayerModel =
- new FunctionalLayerModel(layer.activationFunction)
+ override def grad(delta: BDM[Double], input: BDM[Double], cumGrad: BDV[Double]): Unit = {}
}
/**
* Trait for the artificial neural network (ANN) topology properties
*/
-private[ann] trait Topology extends Serializable{
- def getInstance(weights: Vector): TopologyModel
- def getInstance(seed: Long): TopologyModel
+private[ann] trait Topology extends Serializable {
+ def model(weights: Vector): TopologyModel
+ def model(seed: Long): TopologyModel
}
/**
* Trait for ANN topology model
*/
-private[ann] trait TopologyModel extends Serializable{
+private[ann] trait TopologyModel extends Serializable {
+
+ val weights: Vector
+ /**
+ * Array of layers
+ */
+ val layers: Array[Layer]
+
+ /**
+ * Array of layer models
+ */
+ val layerModels: Array[LayerModel]
/**
* Forward propagation
+ *
* @param data input data
* @return array of outputs for each of the layers
*/
@@ -474,6 +362,7 @@ private[ann] trait TopologyModel extends Serializable{
/**
* Prediction of the model
+ *
* @param data input data
* @return prediction
*/
@@ -481,6 +370,7 @@ private[ann] trait TopologyModel extends Serializable{
/**
* Computes gradient for the network
+ *
* @param data input data
* @param target target output
* @param cumGradient cumulative gradient
@@ -489,22 +379,17 @@ private[ann] trait TopologyModel extends Serializable{
*/
def computeGradient(data: BDM[Double], target: BDM[Double], cumGradient: Vector,
blockSize: Int): Double
-
- /**
- * Returns the weights of the ANN
- * @return weights
- */
- def weights(): Vector
}
/**
* Feed forward ANN
+ *
* @param layers
*/
private[ann] class FeedForwardTopology private(val layers: Array[Layer]) extends Topology {
- override def getInstance(weights: Vector): TopologyModel = FeedForwardModel(this, weights)
+ override def model(weights: Vector): TopologyModel = FeedForwardModel(this, weights)
- override def getInstance(seed: Long): TopologyModel = FeedForwardModel(this, seed)
+ override def model(seed: Long): TopologyModel = FeedForwardModel(this, seed)
}
/**
@@ -513,6 +398,7 @@ private[ann] class FeedForwardTopology private(val layers: Array[Layer]) extends
private[ml] object FeedForwardTopology {
/**
* Creates a feed forward topology from the array of layers
+ *
* @param layers array of layers
* @return feed forward topology
*/
@@ -522,18 +408,26 @@ private[ml] object FeedForwardTopology {
/**
* Creates a multi-layer perceptron
+ *
* @param layerSizes sizes of layers including input and output size
- * @param softmax whether to use SoftMax or Sigmoid function for an output layer.
+ * @param softmaxOnTop whether to use SoftMax or Sigmoid function for the output layer.
* Softmax is default
* @return multilayer perceptron topology
*/
- def multiLayerPerceptron(layerSizes: Array[Int], softmax: Boolean = true): FeedForwardTopology = {
+ def multiLayerPerceptron(
+ layerSizes: Array[Int],
+ softmaxOnTop: Boolean = true): FeedForwardTopology = {
val layers = new Array[Layer]((layerSizes.length - 1) * 2)
- for(i <- 0 until layerSizes.length - 1) {
+ for (i <- 0 until layerSizes.length - 1) {
layers(i * 2) = new AffineLayer(layerSizes(i), layerSizes(i + 1))
layers(i * 2 + 1) =
- if (softmax && i == layerSizes.length - 2) {
- new FunctionalLayer(new SoftmaxFunction())
+ if (i == layerSizes.length - 2) {
+ if (softmaxOnTop) {
+ new SoftmaxLayerWithCrossEntropyLoss()
+ } else {
+ // TODO: squared error is more natural but converges slower
+ new SigmoidLayerWithSquaredError()
+ }
} else {
new FunctionalLayer(new SigmoidFunction())
}
@@ -545,17 +439,45 @@ private[ml] object FeedForwardTopology {
/**
* Model of Feed Forward Neural Network.
* Implements forward, gradient computation and can return weights in vector format.
- * @param layerModels models of layers
- * @param topology topology of the network
+ *
+ * @param weights network weights
+ * @param topology network topology
*/
private[ml] class FeedForwardModel private(
- val layerModels: Array[LayerModel],
+ val weights: Vector,
val topology: FeedForwardTopology) extends TopologyModel {
+
+ val layers = topology.layers
+ val layerModels = new Array[LayerModel](layers.length)
+ private var offset = 0
+ for (i <- 0 until layers.length) {
+ layerModels(i) = layers(i).createModel(
+ new BDV[Double](weights.toArray, offset, 1, layers(i).weightSize))
+ offset += layers(i).weightSize
+ }
+ private var outputs: Array[BDM[Double]] = null
+ private var deltas: Array[BDM[Double]] = null
+
override def forward(data: BDM[Double]): Array[BDM[Double]] = {
- val outputs = new Array[BDM[Double]](layerModels.length)
- outputs(0) = layerModels(0).eval(data)
+ // Initialize output arrays for all layers. Special treatment for InPlace
+ val currentBatchSize = data.cols
+ // TODO: allocate outputs as one big array and then create BDMs from it
+ if (outputs == null || outputs(0).cols != currentBatchSize) {
+ outputs = new Array[BDM[Double]](layers.length)
+ var inputSize = data.rows
+ for (i <- 0 until layers.length) {
+ if (layers(i).inPlace) {
+ outputs(i) = outputs(i - 1)
+ } else {
+ val outputSize = layers(i).getOutputSize(inputSize)
+ outputs(i) = new BDM[Double](outputSize, currentBatchSize)
+ inputSize = outputSize
+ }
+ }
+ }
+ layerModels(0).eval(data, outputs(0))
for (i <- 1 until layerModels.length) {
- outputs(i) = layerModels(i).eval(outputs(i-1))
+ layerModels(i).eval(outputs(i - 1), outputs(i))
}
outputs
}
@@ -566,54 +488,36 @@ private[ml] class FeedForwardModel private(
cumGradient: Vector,
realBatchSize: Int): Double = {
val outputs = forward(data)
- val deltas = new Array[BDM[Double]](layerModels.length)
+ val currentBatchSize = data.cols
+ // TODO: allocate deltas as one big array and then create BDMs from it
+ if (deltas == null || deltas(0).cols != currentBatchSize) {
+ deltas = new Array[BDM[Double]](layerModels.length)
+ var inputSize = data.rows
+ for (i <- 0 until layerModels.length - 1) {
+ val outputSize = layers(i).getOutputSize(inputSize)
+ deltas(i) = new BDM[Double](outputSize, currentBatchSize)
+ inputSize = outputSize
+ }
+ }
val L = layerModels.length - 1
- val (newE, newError) = layerModels.last match {
- case flm: FunctionalLayerModel => flm.error(outputs.last, target)
+ // TODO: explain why delta of top layer is null (because it might contain loss+layer)
+ val loss = layerModels.last match {
+ case levelWithError: LossFunction => levelWithError.loss(outputs.last, target, deltas(L - 1))
case _ =>
- throw new UnsupportedOperationException("Non-functional layer not supported at the top")
+ throw new UnsupportedOperationException("Top layer is required to have objective.")
}
- deltas(L) = new BDM[Double](0, 0)
- deltas(L - 1) = newE
for (i <- (L - 2) to (0, -1)) {
- deltas(i) = layerModels(i + 1).prevDelta(deltas(i + 1), outputs(i + 1))
- }
- val grads = new Array[Array[Double]](layerModels.length)
- for (i <- 0 until layerModels.length) {
- val input = if (i==0) data else outputs(i - 1)
- grads(i) = layerModels(i).grad(deltas(i), input)
+ layerModels(i + 1).computePrevDelta(deltas(i + 1), outputs(i + 1), deltas(i))
}
- // update cumGradient
val cumGradientArray = cumGradient.toArray
var offset = 0
- // TODO: extract roll
- for (i <- 0 until grads.length) {
- val gradArray = grads(i)
- var k = 0
- while (k < gradArray.length) {
- cumGradientArray(offset + k) += gradArray(k)
- k += 1
- }
- offset += gradArray.length
- }
- newError
- }
-
- // TODO: do we really need to copy the weights? they should be read-only
- override def weights(): Vector = {
- // TODO: extract roll
- var size = 0
- for (i <- 0 until layerModels.length) {
- size += layerModels(i).size
- }
- val array = new Array[Double](size)
- var offset = 0
for (i <- 0 until layerModels.length) {
- val layerWeights = layerModels(i).weights().toArray
- System.arraycopy(layerWeights, 0, array, offset, layerWeights.length)
- offset += layerWeights.length
+ val input = if (i == 0) data else outputs(i - 1)
+ layerModels(i).grad(deltas(i), input,
+ new BDV[Double](cumGradientArray, offset, 1, layers(i).weightSize))
+ offset += layers(i).weightSize
}
- Vectors.dense(array)
+ loss
}
override def predict(data: Vector): Vector = {
@@ -630,23 +534,19 @@ private[ann] object FeedForwardModel {
/**
* Creates a model from a topology and weights
+ *
* @param topology topology
* @param weights weights
* @return model
*/
def apply(topology: FeedForwardTopology, weights: Vector): FeedForwardModel = {
- val layers = topology.layers
- val layerModels = new Array[LayerModel](layers.length)
- var offset = 0
- for (i <- 0 until layers.length) {
- layerModels(i) = layers(i).getInstance(weights, offset)
- offset += layerModels(i).size
- }
- new FeedForwardModel(layerModels, topology)
+ // TODO: check that weights size is equal to sum of layers sizes
+ new FeedForwardModel(weights, topology)
}
/**
* Creates a model given a topology and seed
+ *
* @param topology topology
* @param seed seed for generating the weights
* @return model
@@ -654,17 +554,25 @@ private[ann] object FeedForwardModel {
def apply(topology: FeedForwardTopology, seed: Long = 11L): FeedForwardModel = {
val layers = topology.layers
val layerModels = new Array[LayerModel](layers.length)
+ var totalSize = 0
+ for (i <- 0 until topology.layers.length) {
+ totalSize += topology.layers(i).weightSize
+ }
+ val weights = BDV.zeros[Double](totalSize)
var offset = 0
- for(i <- 0 until layers.length) {
- layerModels(i) = layers(i).getInstance(seed)
- offset += layerModels(i).size
+ val random = new XORShiftRandom(seed)
+ for (i <- 0 until layers.length) {
+ layerModels(i) = layers(i).
+ initModel(new BDV[Double](weights.data, offset, 1, layers(i).weightSize), random)
+ offset += layers(i).weightSize
}
- new FeedForwardModel(layerModels, topology)
+ new FeedForwardModel(Vectors.fromBreeze(weights), topology)
}
}
/**
* Neural network gradient. Does nothing but calling Model's gradient
+ *
* @param topology topology
* @param dataStacker data stacker
*/
@@ -682,7 +590,7 @@ private[ann] class ANNGradient(topology: Topology, dataStacker: DataStacker) ext
weights: Vector,
cumGradient: Vector): Double = {
val (input, target, realBatchSize) = dataStacker.unstack(data)
- val model = topology.getInstance(weights)
+ val model = topology.model(weights)
model.computeGradient(input, target, cumGradient, realBatchSize)
}
}
@@ -692,6 +600,7 @@ private[ann] class ANNGradient(topology: Topology, dataStacker: DataStacker) ext
* through Optimizer/Gradient interfaces. If stackSize is more than one, makes blocks
* or matrices of inputs and outputs and then stack them in one vector.
* This can be used for further batch computations after unstacking.
+ *
* @param stackSize stack size
* @param inputSize size of the input vectors
* @param outputSize size of the output vectors
@@ -701,6 +610,7 @@ private[ann] class DataStacker(stackSize: Int, inputSize: Int, outputSize: Int)
/**
* Stacks the data
+ *
* @param data RDD of vector pairs
* @return RDD of double (always zero) and vector that contains the stacked vectors
*/
@@ -733,6 +643,7 @@ private[ann] class DataStacker(stackSize: Int, inputSize: Int, outputSize: Int)
/**
* Unstack the stacked vectors into matrices for batch operations
+ *
* @param data stacked vector
* @return pair of matrices holding input and output data and the real stack size
*/
@@ -765,6 +676,7 @@ private[ann] class ANNUpdater extends Updater {
/**
* MLlib-style trainer class that trains a network given the data and topology
+ *
* @param topology topology of ANN
* @param inputSize input size
* @param outputSize output size
@@ -774,8 +686,8 @@ private[ml] class FeedForwardTrainer(
val inputSize: Int,
val outputSize: Int) extends Serializable {
- // TODO: what if we need to pass random seed?
- private var _weights = topology.getInstance(11L).weights()
+ private var _seed = this.getClass.getName.hashCode.toLong
+ private var _weights: Vector = null
private var _stackSize = 128
private var dataStacker = new DataStacker(_stackSize, inputSize, outputSize)
private var _gradient: Gradient = new ANNGradient(topology, dataStacker)
@@ -783,27 +695,41 @@ private[ml] class FeedForwardTrainer(
private var optimizer: Optimizer = LBFGSOptimizer.setConvergenceTol(1e-4).setNumIterations(100)
/**
+ * Returns seed
+ */
+ def getSeed: Long = _seed
+
+ /**
+ * Sets seed
+ */
+ def setSeed(value: Long): this.type = {
+ _seed = value
+ this
+ }
+
+ /**
* Returns weights
- * @return weights
*/
def getWeights: Vector = _weights
/**
* Sets weights
+ *
* @param value weights
* @return trainer
*/
- def setWeights(value: Vector): FeedForwardTrainer = {
+ def setWeights(value: Vector): this.type = {
_weights = value
this
}
/**
* Sets the stack size
+ *
* @param value stack size
* @return trainer
*/
- def setStackSize(value: Int): FeedForwardTrainer = {
+ def setStackSize(value: Int): this.type = {
_stackSize = value
dataStacker = new DataStacker(value, inputSize, outputSize)
this
@@ -811,6 +737,7 @@ private[ml] class FeedForwardTrainer(
/**
* Sets the SGD optimizer
+ *
* @return SGD optimizer
*/
def SGDOptimizer: GradientDescent = {
@@ -821,6 +748,7 @@ private[ml] class FeedForwardTrainer(
/**
* Sets the LBFGS optimizer
+ *
* @return LBGS optimizer
*/
def LBFGSOptimizer: LBFGS = {
@@ -831,10 +759,11 @@ private[ml] class FeedForwardTrainer(
/**
* Sets the updater
+ *
* @param value updater
* @return trainer
*/
- def setUpdater(value: Updater): FeedForwardTrainer = {
+ def setUpdater(value: Updater): this.type = {
_updater = value
updateUpdater(value)
this
@@ -842,10 +771,11 @@ private[ml] class FeedForwardTrainer(
/**
* Sets the gradient
+ *
* @param value gradient
* @return trainer
*/
- def setGradient(value: Gradient): FeedForwardTrainer = {
+ def setGradient(value: Gradient): this.type = {
_gradient = value
updateGradient(value)
this
@@ -871,12 +801,20 @@ private[ml] class FeedForwardTrainer(
/**
* Trains the ANN
+ *
* @param data RDD of input and output vector pairs
* @return model
*/
def train(data: RDD[(Vector, Vector)]): TopologyModel = {
- val newWeights = optimizer.optimize(dataStacker.stack(data), getWeights)
- topology.getInstance(newWeights)
+ val w = if (getWeights == null) {
+ // TODO: will make a copy if vector is a subvector of BDV (see Vectors code)
+ topology.model(_seed).weights
+ } else {
+ getWeights
+ }
+ // TODO: deprecate standard optimizer because it needs Vector
+ val newWeights = optimizer.optimize(dataStacker.stack(data), w)
+ topology.model(newWeights)
}
}
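
The refactored contract has the caller preallocate one output matrix per layer, sized (outputSize, batchSize) and reused as long as the batch size stays the same, while eval writes into it instead of allocating. A minimal Breeze sketch of that in-place evaluation for a sigmoid layer, assuming only Breeze on the classpath:

import breeze.linalg.{DenseMatrix => BDM}

// Sketch of the in-place eval contract: the caller preallocates an output
// matrix of shape (outputSize, batchSize) and the layer writes into it.
object InPlaceEvalSketch {
  def evalSigmoid(data: BDM[Double], output: BDM[Double]): Unit = {
    var j = 0
    while (j < data.cols) {
      var i = 0
      while (i < data.rows) {
        output(i, j) = 1.0 / (1.0 + math.exp(-data(i, j)))
        i += 1
      }
      j += 1
    }
  }

  def main(args: Array[String]): Unit = {
    val batch = new BDM(4, 8, Array.tabulate(32)(i => i * 0.1))  // 4 features, batch of 8 columns
    val output = new BDM[Double](4, 8)                           // allocated once per batch size, then reused
    evalSigmoid(batch, output)
    println(output(0, 0))
  }
}
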
diff --git a/mllib/src/main/scala/org/apache/spark/ml/ann/LossFunction.scala b/mllib/src/main/scala/org/apache/spark/ml/ann/LossFunction.scala
new file mode 100644
index 0000000000..32d78e9b22
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/ann/LossFunction.scala
@@ -0,0 +1,124 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.ann
+
+import java.util.Random
+
+import breeze.linalg.{sum => Bsum, DenseMatrix => BDM, DenseVector => BDV}
+import breeze.numerics.{log => brzlog}
+
+/**
+ * Trait for loss function
+ */
+private[ann] trait LossFunction {
+ /**
+ * Returns the value of the loss function.
+ * Computes the loss based on target and output,
+ * and writes the delta (error) to delta in place.
+ * Delta is allocated based on the outputSize
+ * of the model implementation.
+ *
+ * @param output actual output
+ * @param target target output
+ * @param delta delta (updated in place)
+ * @return loss
+ */
+ def loss(output: BDM[Double], target: BDM[Double], delta: BDM[Double]): Double
+}
+
+private[ann] class SigmoidLayerWithSquaredError extends Layer {
+ override val weightSize = 0
+ override val inPlace = true
+
+ override def getOutputSize(inputSize: Int): Int = inputSize
+ override def createModel(weights: BDV[Double]): LayerModel =
+ new SigmoidLayerModelWithSquaredError()
+ override def initModel(weights: BDV[Double], random: Random): LayerModel =
+ new SigmoidLayerModelWithSquaredError()
+}
+
+private[ann] class SigmoidLayerModelWithSquaredError
+ extends FunctionalLayerModel(new FunctionalLayer(new SigmoidFunction)) with LossFunction {
+ override def loss(output: BDM[Double], target: BDM[Double], delta: BDM[Double]): Double = {
+ ApplyInPlace(output, target, delta, (o: Double, t: Double) => o - t)
+ val error = Bsum(delta :* delta) / 2 / output.cols
+ ApplyInPlace(delta, output, delta, (x: Double, o: Double) => x * (o - o * o))
+ error
+ }
+}
+
+private[ann] class SoftmaxLayerWithCrossEntropyLoss extends Layer {
+ override val weightSize = 0
+ override val inPlace = true
+
+ override def getOutputSize(inputSize: Int): Int = inputSize
+ override def createModel(weights: BDV[Double]): LayerModel =
+ new SoftmaxLayerModelWithCrossEntropyLoss()
+ override def initModel(weights: BDV[Double], random: Random): LayerModel =
+ new SoftmaxLayerModelWithCrossEntropyLoss()
+}
+
+private[ann] class SoftmaxLayerModelWithCrossEntropyLoss extends LayerModel with LossFunction {
+
+ // loss layer models do not have weights
+ val weights = new BDV[Double](0)
+
+ override def eval(data: BDM[Double], output: BDM[Double]): Unit = {
+ var j = 0
+ // find max value to make sure later that exponent is computable
+ while (j < data.cols) {
+ var i = 0
+ var max = Double.MinValue
+ while (i < data.rows) {
+ if (data(i, j) > max) {
+ max = data(i, j)
+ }
+ i += 1
+ }
+ var sum = 0.0
+ i = 0
+ while (i < data.rows) {
+ val res = math.exp(data(i, j) - max)
+ output(i, j) = res
+ sum += res
+ i += 1
+ }
+ i = 0
+ while (i < data.rows) {
+ output(i, j) /= sum
+ i += 1
+ }
+ j += 1
+ }
+ }
+ override def computePrevDelta(
+ nextDelta: BDM[Double],
+ input: BDM[Double],
+ delta: BDM[Double]): Unit = {
+ /* loss layer model computes delta in loss function */
+ }
+
+ override def grad(delta: BDM[Double], input: BDM[Double], cumGrad: BDV[Double]): Unit = {
+ /* loss layer model does not have weights */
+ }
+
+ override def loss(output: BDM[Double], target: BDM[Double], delta: BDM[Double]): Double = {
+ ApplyInPlace(output, target, delta, (o: Double, t: Double) => o - t)
+ -Bsum( target :* brzlog(output)) / output.cols
+ }
+}
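
The SoftmaxLayerModelWithCrossEntropyLoss added above shifts each column by its maximum before exponentiating so math.exp cannot overflow, then normalizes, and its loss is -sum(target :* log(output)) averaged over columns. A plain-Scala sketch of the same computation for a single column (arrays instead of Breeze matrices; all names are illustrative):

    // Numerically stable softmax: subtract the max before exponentiating.
    def softmax(logits: Array[Double]): Array[Double] = {
      val max = logits.max
      val exps = logits.map(x => math.exp(x - max))
      val total = exps.sum
      exps.map(_ / total)
    }

    // Cross-entropy for one example: -sum(target * log(output)).
    def crossEntropy(output: Array[Double], target: Array[Double]): Double =
      -output.zip(target).map { case (o, t) => t * math.log(o) }.sum

    // Three classes, true class at index 1.
    val loss = crossEntropy(softmax(Array(2.0, 1.0, 0.1)), Array(0.0, 1.0, 0.0))
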
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala
index 8186afc17a..473e801794 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala
@@ -22,7 +22,7 @@ import org.apache.spark.ml.{PredictionModel, Predictor, PredictorParams}
import org.apache.spark.ml.param.shared.HasRawPredictionCol
import org.apache.spark.ml.util.SchemaUtils
import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
-import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{DataType, StructType}
@@ -92,7 +92,7 @@ abstract class ClassificationModel[FeaturesType, M <: ClassificationModel[Featur
* @param dataset input dataset
* @return transformed dataset
*/
- override def transform(dataset: DataFrame): DataFrame = {
+ override def transform(dataset: Dataset[_]): DataFrame = {
transformSchema(dataset.schema, logging = true)
// Output selected columns only.
@@ -123,7 +123,7 @@ abstract class ClassificationModel[FeaturesType, M <: ClassificationModel[Featur
logWarning(s"$uid: ClassificationModel.transform() was called as NOOP" +
" since no output columns were set.")
}
- outputData
+ outputData.toDF
}
/**
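
The signature change above (transform takes Dataset[_] instead of DataFrame and returns outputData.toDF) recurs throughout this patch: callers may pass either an untyped DataFrame or a typed Dataset, and toDF converts back where a DataFrame result is required. A short sketch of the calling pattern (the SparkSession, case class, and variable names are assumptions):

    import org.apache.spark.sql.{DataFrame, Dataset}

    // A method declared against Dataset[_] accepts both DataFrames and typed Datasets.
    def firstRows(data: Dataset[_]): DataFrame = data.toDF.limit(5)

    // assumes a SparkSession named `spark` and `import spark.implicits._`
    // case class Point(label: Double, feature: Double)
    // val df: DataFrame = Seq(Point(1.0, 2.0)).toDF()
    // val ds: Dataset[Point] = df.as[Point]
    // firstRows(df); firstRows(ds)   // both compile against Dataset[_]
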
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
index 3e4b21bff6..300ae4339c 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
@@ -32,7 +32,7 @@ import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy}
import org.apache.spark.mllib.tree.model.{DecisionTreeModel => OldDecisionTreeModel}
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.{DataFrame, Dataset}
/**
@@ -82,7 +82,7 @@ final class DecisionTreeClassifier @Since("1.4.0") (
@Since("1.6.0")
override def setSeed(value: Long): this.type = super.setSeed(value)
- override protected def train(dataset: DataFrame): DecisionTreeClassificationModel = {
+ override protected def train(dataset: Dataset[_]): DecisionTreeClassificationModel = {
val categoricalFeatures: Map[Int, Int] =
MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
val numClasses: Int = MetadataUtils.getNumClasses(dataset.schema($(labelCol))) match {
@@ -203,9 +203,9 @@ final class DecisionTreeClassificationModel private[ml] (
* to determine feature importance instead.
*/
@Since("2.0.0")
- lazy val featureImportances: Vector = RandomForest.featureImportances(this, numFeatures)
+ lazy val featureImportances: Vector = TreeEnsembleModel.featureImportances(this, numFeatures)
- /** Convert to spark.mllib DecisionTreeModel (losing some infomation) */
+ /** Convert to spark.mllib DecisionTreeModel (losing some information) */
override private[spark] def toOld: OldDecisionTreeModel = {
new OldDecisionTreeModel(rootNode.toOld(1), OldAlgo.Classification)
}
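
featureImportances above now delegates to TreeEnsembleModel.featureImportances but is read from a fitted model the same way as before. A usage sketch (the training DataFrame with "label" and "features" columns is an assumption):

    import org.apache.spark.ml.classification.DecisionTreeClassifier

    // assumes `training` has a Double "label" column and a Vector "features" column
    val dt = new DecisionTreeClassifier().setMaxDepth(5)
    val dtModel = dt.fit(training)
    // per-feature importances, normalized to sum to 1
    println(dtModel.featureImportances)
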
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
index c31df3aa18..39a698af15 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
@@ -18,23 +18,24 @@
package org.apache.spark.ml.classification
import com.github.fommil.netlib.BLAS.{getInstance => blas}
+import org.json4s.{DefaultFormats, JObject}
+import org.json4s.JsonDSL._
import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.internal.Logging
import org.apache.spark.ml.{PredictionModel, Predictor}
-import org.apache.spark.ml.param.{Param, ParamMap}
+import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.regression.DecisionTreeRegressionModel
-import org.apache.spark.ml.tree.{DecisionTreeModel, GBTParams, TreeClassifierParams,
- TreeEnsembleModel}
+import org.apache.spark.ml.tree._
import org.apache.spark.ml.tree.impl.GradientBoostedTrees
-import org.apache.spark.ml.util.{Identifiable, MetadataUtils}
+import org.apache.spark.ml.util._
+import org.apache.spark.ml.util.DefaultParamsReader.Metadata
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
-import org.apache.spark.mllib.tree.loss.{LogLoss => OldLogLoss, Loss => OldLoss}
import org.apache.spark.mllib.tree.model.{GradientBoostedTreesModel => OldGBTModel}
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.functions._
/**
@@ -43,13 +44,23 @@ import org.apache.spark.sql.functions._
* learning algorithm for classification.
* It supports binary labels, as well as both continuous and categorical features.
* Note: Multiclass labels are not currently supported.
+ *
+ * The implementation is based upon: J.H. Friedman. "Stochastic Gradient Boosting." 1999.
+ *
+ * Notes on Gradient Boosting vs. TreeBoost:
+ * - This implementation is for Stochastic Gradient Boosting, not for TreeBoost.
+ * - Both algorithms learn tree ensembles by minimizing loss functions.
+ * - TreeBoost (Friedman, 1999) additionally modifies the outputs at tree leaf nodes
+ * based on the loss function, whereas the original gradient boosting method does not.
+ * - We expect to implement TreeBoost in the future:
+ * [https://issues.apache.org/jira/browse/SPARK-4240]
*/
@Since("1.4.0")
@Experimental
final class GBTClassifier @Since("1.4.0") (
@Since("1.4.0") override val uid: String)
extends Predictor[Vector, GBTClassifier, GBTClassificationModel]
- with GBTParams with TreeClassifierParams with Logging {
+ with GBTClassifierParams with DefaultParamsWritable with Logging {
@Since("1.4.0")
def this() = this(Identifiable.randomUID("gbtc"))
@@ -106,41 +117,13 @@ final class GBTClassifier @Since("1.4.0") (
@Since("1.4.0")
override def setStepSize(value: Double): this.type = super.setStepSize(value)
- // Parameters for GBTClassifier:
-
- /**
- * Loss function which GBT tries to minimize. (case-insensitive)
- * Supported: "logistic"
- * (default = logistic)
- * @group param
- */
- @Since("1.4.0")
- val lossType: Param[String] = new Param[String](this, "lossType", "Loss function which GBT" +
- " tries to minimize (case-insensitive). Supported options:" +
- s" ${GBTClassifier.supportedLossTypes.mkString(", ")}",
- (value: String) => GBTClassifier.supportedLossTypes.contains(value.toLowerCase))
-
- setDefault(lossType -> "logistic")
+ // Parameters from GBTClassifierParams:
/** @group setParam */
@Since("1.4.0")
def setLossType(value: String): this.type = set(lossType, value)
- /** @group getParam */
- @Since("1.4.0")
- def getLossType: String = $(lossType).toLowerCase
-
- /** (private[ml]) Convert new loss to old loss. */
- override private[ml] def getOldLossType: OldLoss = {
- getLossType match {
- case "logistic" => OldLogLoss
- case _ =>
- // Should never happen because of check in setter method.
- throw new RuntimeException(s"GBTClassifier was given bad loss type: $getLossType")
- }
- }
-
- override protected def train(dataset: DataFrame): GBTClassificationModel = {
+ override protected def train(dataset: Dataset[_]): GBTClassificationModel = {
val categoricalFeatures: Map[Int, Int] =
MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
val numClasses: Int = MetadataUtils.getNumClasses(dataset.schema($(labelCol))) match {
@@ -166,11 +149,14 @@ final class GBTClassifier @Since("1.4.0") (
@Since("1.4.0")
@Experimental
-object GBTClassifier {
- // The losses below should be lowercase.
+object GBTClassifier extends DefaultParamsReadable[GBTClassifier] {
+
/** Accessor for supported loss settings: logistic */
@Since("1.4.0")
- final val supportedLossTypes: Array[String] = Array("logistic").map(_.toLowerCase)
+ final val supportedLossTypes: Array[String] = GBTClassifierParams.supportedLossTypes
+
+ @Since("2.0.0")
+ override def load(path: String): GBTClassifier = super.load(path)
}
/**
@@ -190,9 +176,10 @@ final class GBTClassificationModel private[ml](
private val _treeWeights: Array[Double],
@Since("1.6.0") override val numFeatures: Int)
extends PredictionModel[Vector, GBTClassificationModel]
- with TreeEnsembleModel with Serializable {
+ with GBTClassifierParams with TreeEnsembleModel[DecisionTreeRegressionModel]
+ with MLWritable with Serializable {
- require(numTrees > 0, "GBTClassificationModel requires at least 1 tree.")
+ require(_trees.nonEmpty, "GBTClassificationModel requires at least 1 tree.")
require(_trees.length == _treeWeights.length, "GBTClassificationModel given trees, treeWeights" +
s" of non-matching lengths (${_trees.length}, ${_treeWeights.length}, respectively).")
@@ -206,12 +193,12 @@ final class GBTClassificationModel private[ml](
this(uid, _trees, _treeWeights, -1)
@Since("1.4.0")
- override def trees: Array[DecisionTreeModel] = _trees.asInstanceOf[Array[DecisionTreeModel]]
+ override def trees: Array[DecisionTreeRegressionModel] = _trees
@Since("1.4.0")
override def treeWeights: Array[Double] = _treeWeights
- override protected def transformImpl(dataset: DataFrame): DataFrame = {
+ override protected def transformImpl(dataset: Dataset[_]): DataFrame = {
val bcastModel = dataset.sqlContext.sparkContext.broadcast(this)
val predictUDF = udf { (features: Any) =>
bcastModel.value.predict(features.asInstanceOf[Vector])
@@ -227,6 +214,9 @@ final class GBTClassificationModel private[ml](
if (prediction > 0.0) 1.0 else 0.0
}
+ /** Number of trees in ensemble */
+ val numTrees: Int = trees.length
+
@Since("1.4.0")
override def copy(extra: ParamMap): GBTClassificationModel = {
copyValues(new GBTClassificationModel(uid, _trees, _treeWeights, numFeatures),
@@ -238,16 +228,79 @@ final class GBTClassificationModel private[ml](
s"GBTClassificationModel (uid=$uid) with $numTrees trees"
}
+ /**
+ * Estimate of the importance of each feature.
+ *
+ * Each feature's importance is the average of its importance across all trees in the ensemble.
+ * The importance vector is normalized to sum to 1. This method is suggested by Hastie et al.
+ * (Hastie, Tibshirani, Friedman. "The Elements of Statistical Learning, 2nd Edition." 2001.)
+ * and follows the implementation from scikit-learn.
+ *
+ * @see [[DecisionTreeClassificationModel.featureImportances]]
+ */
+ @Since("2.0.0")
+ lazy val featureImportances: Vector = TreeEnsembleModel.featureImportances(trees, numFeatures)
+
/** (private[ml]) Convert to a model in the old API */
private[ml] def toOld: OldGBTModel = {
new OldGBTModel(OldAlgo.Classification, _trees.map(_.toOld), _treeWeights)
}
+
+ @Since("2.0.0")
+ override def write: MLWriter = new GBTClassificationModel.GBTClassificationModelWriter(this)
}
-private[ml] object GBTClassificationModel {
+@Since("2.0.0")
+object GBTClassificationModel extends MLReadable[GBTClassificationModel] {
+
+ @Since("2.0.0")
+ override def read: MLReader[GBTClassificationModel] = new GBTClassificationModelReader
+
+ @Since("2.0.0")
+ override def load(path: String): GBTClassificationModel = super.load(path)
+
+ private[GBTClassificationModel]
+ class GBTClassificationModelWriter(instance: GBTClassificationModel) extends MLWriter {
+
+ override protected def saveImpl(path: String): Unit = {
+
+ val extraMetadata: JObject = Map(
+ "numFeatures" -> instance.numFeatures,
+ "numTrees" -> instance.getNumTrees)
+ EnsembleModelReadWrite.saveImpl(instance, path, sqlContext, extraMetadata)
+ }
+ }
+
+ private class GBTClassificationModelReader extends MLReader[GBTClassificationModel] {
+
+ /** Checked against metadata when loading model */
+ private val className = classOf[GBTClassificationModel].getName
+ private val treeClassName = classOf[DecisionTreeRegressionModel].getName
+
+ override def load(path: String): GBTClassificationModel = {
+ implicit val format = DefaultFormats
+ val (metadata: Metadata, treesData: Array[(Metadata, Node)], treeWeights: Array[Double]) =
+ EnsembleModelReadWrite.loadImpl(path, sqlContext, className, treeClassName)
+ val numFeatures = (metadata.metadata \ "numFeatures").extract[Int]
+ val numTrees = (metadata.metadata \ "numTrees").extract[Int]
+
+ val trees: Array[DecisionTreeRegressionModel] = treesData.map {
+ case (treeMetadata, root) =>
+ val tree =
+ new DecisionTreeRegressionModel(treeMetadata.uid, root, numFeatures)
+ DefaultParamsReader.getAndSetParams(tree, treeMetadata)
+ tree
+ }
+ require(numTrees == trees.length, s"GBTClassificationModel.load expected $numTrees" +
+ s" trees based on metadata but found ${trees.length} trees.")
+ val model = new GBTClassificationModel(metadata.uid, trees, treeWeights, numFeatures)
+ DefaultParamsReader.getAndSetParams(model, metadata)
+ model
+ }
+ }
- /** (private[ml]) Convert a model from the old API */
- def fromOld(
+ /** Convert a model from the old API */
+ private[ml] def fromOld(
oldModel: OldGBTModel,
parent: GBTClassifier,
categoricalFeatures: Map[Int, Int],
@@ -259,6 +312,6 @@ private[ml] object GBTClassificationModel {
DecisionTreeRegressionModel.fromOld(tree, null, categoricalFeatures)
}
val uid = if (parent != null) parent.uid else Identifiable.randomUID("gbtc")
- new GBTClassificationModel(parent.uid, newTrees, oldModel.treeWeights, numFeatures)
+ new GBTClassificationModel(uid, newTrees, oldModel.treeWeights, numFeatures)
}
}
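
GBTClassifier and GBTClassificationModel gain ML persistence above via DefaultParamsWritable / MLReadable, with tree metadata handled by EnsembleModelReadWrite. A save/load sketch (the path and the training DataFrame are assumptions):

    import org.apache.spark.ml.classification.{GBTClassificationModel, GBTClassifier}

    // assumes `training` has "label" and "features" columns
    val gbt = new GBTClassifier().setMaxIter(10).setLossType("logistic")
    val gbtModel = gbt.fit(training)

    // write/read round trip added by this patch
    gbtModel.write.overwrite().save("/tmp/spark-gbtc-model")
    val restoredGbt = GBTClassificationModel.load("/tmp/spark-gbtc-model")
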
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
index 861b1d4b66..c2b440059b 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -36,8 +36,9 @@ import org.apache.spark.mllib.linalg.BLAS._
import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.{DataFrame, Dataset, Row}
import org.apache.spark.sql.functions.{col, lit}
+import org.apache.spark.sql.types.DoubleType
import org.apache.spark.storage.StorageLevel
/**
@@ -256,22 +257,26 @@ class LogisticRegression @Since("1.2.0") (
this
}
- override protected[spark] def train(dataset: DataFrame): LogisticRegressionModel = {
+ override protected[spark] def train(dataset: Dataset[_]): LogisticRegressionModel = {
val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE
train(dataset, handlePersistence)
}
- protected[spark] def train(dataset: DataFrame, handlePersistence: Boolean):
+ protected[spark] def train(dataset: Dataset[_], handlePersistence: Boolean):
LogisticRegressionModel = {
val w = if ($(weightCol).isEmpty) lit(1.0) else col($(weightCol))
val instances: RDD[Instance] =
- dataset.select(col($(labelCol)), w, col($(featuresCol))).rdd.map {
+ dataset.select(col($(labelCol)).cast(DoubleType), w, col($(featuresCol))).rdd.map {
case Row(label: Double, weight: Double, features: Vector) =>
Instance(label, weight, features)
}
if (handlePersistence) instances.persist(StorageLevel.MEMORY_AND_DISK)
+ val instr = Instrumentation.create(this, instances)
+ instr.logParams(regParam, elasticNetParam, standardization, threshold,
+ maxIter, tol, fitIntercept)
+
val (summarizer, labelSummarizer) = {
val seqOp = (c: (MultivariateOnlineSummarizer, MultiClassSummarizer),
instance: Instance) =>
@@ -290,6 +295,9 @@ class LogisticRegression @Since("1.2.0") (
val numClasses = histogram.length
val numFeatures = summarizer.mean.size
+ instr.logNumClasses(numClasses)
+ instr.logNumFeatures(numFeatures)
+
val (coefficients, intercept, objectiveHistory) = {
if (numInvalid != 0) {
val msg = s"Classification labels should be in {0 to ${numClasses - 1} " +
@@ -361,7 +369,7 @@ class LogisticRegression @Since("1.2.0") (
if (optInitialModel.isDefined && optInitialModel.get.coefficients.size != numFeatures) {
val vec = optInitialModel.get.coefficients
logWarning(
- s"Initial coefficients provided ${vec} did not match the expected size ${numFeatures}")
+ s"Initial coefficients provided $vec did not match the expected size $numFeatures")
}
if (optInitialModel.isDefined && optInitialModel.get.coefficients.size == numFeatures) {
@@ -443,7 +451,9 @@ class LogisticRegression @Since("1.2.0") (
$(labelCol),
$(featuresCol),
objectiveHistory)
- model.setSummary(logRegSummary)
+ val m = model.setSummary(logRegSummary)
+ instr.logSuccess(m)
+ m
}
@Since("1.4.0")
@@ -522,7 +532,7 @@ class LogisticRegressionModel private[spark] (
(LogisticRegressionModel, String) = {
$(probabilityCol) match {
case "" =>
- val probabilityColName = "probability_" + java.util.UUID.randomUUID.toString()
+ val probabilityColName = "probability_" + java.util.UUID.randomUUID.toString
(copy(ParamMap.empty).setProbabilityCol(probabilityColName), probabilityColName)
case p => (this, p)
}
@@ -539,13 +549,15 @@ class LogisticRegressionModel private[spark] (
def hasSummary: Boolean = trainingSummary.isDefined
/**
- * Evaluates the model on a testset.
+ * Evaluates the model on a test dataset.
* @param dataset Test dataset to evaluate model on.
*/
- // TODO: decide on a good name before exposing to public API
- private[classification] def evaluate(dataset: DataFrame): LogisticRegressionSummary = {
- new BinaryLogisticRegressionSummary(
- this.transform(dataset), $(probabilityCol), $(labelCol), $(featuresCol))
+ @Since("2.0.0")
+ def evaluate(dataset: Dataset[_]): LogisticRegressionSummary = {
+ // Handle possible missing or invalid prediction columns
+ val (summaryModel, probabilityColName) = findSummaryModelAndProbabilityCol()
+ new BinaryLogisticRegressionSummary(summaryModel.transform(dataset),
+ probabilityColName, $(labelCol), $(featuresCol))
}
/**
@@ -771,13 +783,13 @@ sealed trait LogisticRegressionTrainingSummary extends LogisticRegressionSummary
*/
sealed trait LogisticRegressionSummary extends Serializable {
- /** Dataframe outputted by the model's `transform` method. */
+ /** Dataframe output by the model's `transform` method. */
def predictions: DataFrame
- /** Field in "predictions" which gives the calibrated probability of each instance as a vector. */
+ /** Field in "predictions" which gives the probability of each class as a vector. */
def probabilityCol: String
- /** Field in "predictions" which gives the true label of each instance. */
+ /** Field in "predictions" which gives the true label of each instance (if available). */
def labelCol: String
/** Field in "predictions" which gives the features of each instance as a vector. */
@@ -789,9 +801,9 @@ sealed trait LogisticRegressionSummary extends Serializable {
* :: Experimental ::
* Logistic regression training results.
*
- * @param predictions dataframe outputted by the model's `transform` method.
- * @param probabilityCol field in "predictions" which gives the calibrated probability of
- * each instance as a vector.
+ * @param predictions dataframe output by the model's `transform` method.
+ * @param probabilityCol field in "predictions" which gives the probability of
+ * each class as a vector.
* @param labelCol field in "predictions" which gives the true label of each instance.
* @param featuresCol field in "predictions" which gives the features of each instance as a vector.
* @param objectiveHistory objective function (scaled loss + regularization) at each iteration.
@@ -813,9 +825,9 @@ class BinaryLogisticRegressionTrainingSummary private[classification] (
* :: Experimental ::
* Binary Logistic regression results for a given model.
*
- * @param predictions dataframe outputted by the model's `transform` method.
- * @param probabilityCol field in "predictions" which gives the calibrated probability of
- * each instance.
+ * @param predictions dataframe output by the model's `transform` method.
+ * @param probabilityCol field in "predictions" which gives the probability of
+ * each class as a vector.
* @param labelCol field in "predictions" which gives the true label of each instance.
* @param featuresCol field in "predictions" which gives the features of each instance as a vector.
*/
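
LogisticRegressionModel.evaluate is promoted to a public API above and goes through findSummaryModelAndProbabilityCol so that a cleared probability column is handled. A usage sketch (the training and test DataFrame names are assumptions):

    import org.apache.spark.ml.classification.LogisticRegression

    // assumes `training` and `test` have "label" and "features" columns
    val lr = new LogisticRegression().setMaxIter(100).setRegParam(0.01)
    val lrModel = lr.fit(training)

    // evaluate() returns a summary computed on the held-out data
    val testSummary = lrModel.evaluate(test)
    testSummary.predictions.show(5)
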
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala
index f6de5f2df4..9ff5252e4f 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala
@@ -24,27 +24,27 @@ import org.apache.hadoop.fs.Path
import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.ml.{PredictionModel, Predictor, PredictorParams}
import org.apache.spark.ml.ann.{FeedForwardTopology, FeedForwardTrainer}
-import org.apache.spark.ml.param.{IntArrayParam, IntParam, ParamMap, ParamValidators}
-import org.apache.spark.ml.param.shared.{HasMaxIter, HasSeed, HasTol}
+import org.apache.spark.ml.param._
+import org.apache.spark.ml.param.shared.{HasMaxIter, HasSeed, HasStepSize, HasTol}
import org.apache.spark.ml.util._
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.regression.LabeledPoint
-import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.{DataFrame, Dataset}
/** Params for Multilayer Perceptron. */
private[ml] trait MultilayerPerceptronParams extends PredictorParams
- with HasSeed with HasMaxIter with HasTol {
+ with HasSeed with HasMaxIter with HasTol with HasStepSize {
/**
* Layer sizes including input size and output size.
* Default: Array(1, 1)
- * @group param
+ *
+ * @group param
*/
final val layers: IntArrayParam = new IntArrayParam(this, "layers",
"Sizes of layers from input layer to output layer" +
" E.g., Array(780, 100, 10) means 780 inputs, " +
"one hidden layer with 100 neurons and output layer of 10 neurons.",
- // TODO: how to check ALSO that all elements are greater than 0?
- ParamValidators.arrayLengthGt(1)
+ (t: Array[Int]) => t.forall(ParamValidators.gt(0)) && t.length > 1
)
/** @group getParam */
@@ -56,7 +56,8 @@ private[ml] trait MultilayerPerceptronParams extends PredictorParams
* a partition then it is adjusted to the size of this data.
* Recommended size is between 10 and 1000.
* Default: 128
- * @group expertParam
+ *
+ * @group expertParam
*/
final val blockSize: IntParam = new IntParam(this, "blockSize",
"Block size for stacking input data in matrices. Data is stacked within partitions." +
@@ -67,7 +68,33 @@ private[ml] trait MultilayerPerceptronParams extends PredictorParams
/** @group getParam */
final def getBlockSize: Int = $(blockSize)
- setDefault(maxIter -> 100, tol -> 1e-4, layers -> Array(1, 1), blockSize -> 128)
+ /**
+ * Allows setting the solver: minibatch gradient descent (gd) or l-bfgs.
+ * l-bfgs is the default one.
+ *
+ * @group expertParam
+ */
+ final val solver: Param[String] = new Param[String](this, "solver",
+ " Allows setting the solver: minibatch gradient descent (gd) or l-bfgs. " +
+ " l-bfgs is the default one.",
+ ParamValidators.inArray[String](Array("gd", "l-bfgs")))
+
+ /** @group getParam */
+ final def getOptimizer: String = $(solver)
+
+ /**
+ * Model weights. Can be returned either after training or after explicit setting
+ *
+ * @group expertParam
+ */
+ final val weights: Param[Vector] = new Param[Vector](this, "weights",
+ " Sets the weights of the model ")
+
+ /** @group getParam */
+ final def getWeights: Vector = $(weights)
+
+
+ setDefault(maxIter -> 100, tol -> 1e-4, blockSize -> 128, solver -> "l-bfgs", stepSize -> 0.03)
}
/** Label to vector converter. */
@@ -106,6 +133,7 @@ private object LabelConverter {
* Each layer has sigmoid activation function, output layer has softmax.
* Number of inputs has to be equal to the size of feature vectors.
* Number of outputs has to be equal to the total number of labels.
+ *
*/
@Since("1.5.0")
@Experimental
@@ -128,7 +156,8 @@ class MultilayerPerceptronClassifier @Since("1.5.0") (
/**
* Set the maximum number of iterations.
* Default is 100.
- * @group setParam
+ *
+ * @group setParam
*/
@Since("1.5.0")
def setMaxIter(value: Int): this.type = set(maxIter, value)
@@ -137,18 +166,28 @@ class MultilayerPerceptronClassifier @Since("1.5.0") (
* Set the convergence tolerance of iterations.
* Smaller value will lead to higher accuracy with the cost of more iterations.
* Default is 1E-4.
- * @group setParam
+ *
+ * @group setParam
*/
@Since("1.5.0")
def setTol(value: Double): this.type = set(tol, value)
/**
- * Set the seed for weights initialization.
- * @group setParam
+ * Set the seed for weights initialization if weights are not set
+ *
+ * @group setParam
*/
@Since("1.5.0")
def setSeed(value: Long): this.type = set(seed, value)
+ /**
+ * Sets the model weights.
+ *
+ * @group expertParam
+ */
+ @Since("2.0.0")
+ def setWeights(value: Vector): this.type = set(weights, value)
+
@Since("1.5.0")
override def copy(extra: ParamMap): MultilayerPerceptronClassifier = defaultCopy(extra)
@@ -160,17 +199,24 @@ class MultilayerPerceptronClassifier @Since("1.5.0") (
* @param dataset Training dataset
* @return Fitted model
*/
- override protected def train(dataset: DataFrame): MultilayerPerceptronClassificationModel = {
+ override protected def train(dataset: Dataset[_]): MultilayerPerceptronClassificationModel = {
val myLayers = $(layers)
val labels = myLayers.last
val lpData = extractLabeledPoints(dataset)
val data = lpData.map(lp => LabelConverter.encodeLabeledPoint(lp, labels))
val topology = FeedForwardTopology.multiLayerPerceptron(myLayers, true)
- val FeedForwardTrainer = new FeedForwardTrainer(topology, myLayers(0), myLayers.last)
- FeedForwardTrainer.LBFGSOptimizer.setConvergenceTol($(tol)).setNumIterations($(maxIter))
- FeedForwardTrainer.setStackSize($(blockSize))
- val mlpModel = FeedForwardTrainer.train(data)
- new MultilayerPerceptronClassificationModel(uid, myLayers, mlpModel.weights())
+ val trainer = new FeedForwardTrainer(topology, myLayers(0), myLayers.last)
+ if (isDefined(weights)) {
+ trainer.setWeights($(weights))
+ } else {
+ trainer.setSeed($(seed))
+ }
+ trainer.LBFGSOptimizer
+ .setConvergenceTol($(tol))
+ .setNumIterations($(maxIter))
+ trainer.setStackSize($(blockSize))
+ val mlpModel = trainer.train(data)
+ new MultilayerPerceptronClassificationModel(uid, myLayers, mlpModel.weights)
}
}
@@ -186,7 +232,8 @@ object MultilayerPerceptronClassifier
* :: Experimental ::
* Classification model based on the Multilayer Perceptron.
* Each layer has sigmoid activation function, output layer has softmax.
- * @param uid uid
+ *
+ * @param uid uid
* @param layers array of layer sizes including input and output layers
* @param weights vector of initial weights for the model that consists of the weights of layers
* @return prediction model
@@ -203,7 +250,7 @@ class MultilayerPerceptronClassificationModel private[ml] (
@Since("1.6.0")
override val numFeatures: Int = layers.head
- private val mlpModel = FeedForwardTopology.multiLayerPerceptron(layers, true).getInstance(weights)
+ private val mlpModel = FeedForwardTopology.multiLayerPerceptron(layers, true).model(weights)
/**
* Returns layers in a Java List.
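
The trainer wiring above seeds weight initialization only when no initial weights are supplied, and the new solver/stepSize params default to l-bfgs and 0.03. A usage sketch (the layer sizes and the training DataFrame are assumptions):

    import org.apache.spark.ml.classification.MultilayerPerceptronClassifier

    // assumes `training` has "label"/"features" columns, 4 input features, 3 classes
    val layers = Array[Int](4, 5, 4, 3)
    val mlp = new MultilayerPerceptronClassifier()
      .setLayers(layers)
      .setBlockSize(128)
      .setSeed(1234L)   // used for weight initialization because no weights are set
      .setMaxIter(100)
    val mlpModel = mlp.fit(training)
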
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
index 483ef0d88c..267d63b51e 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
@@ -29,7 +29,7 @@ import org.apache.spark.mllib.classification.{NaiveBayesModel => OldNaiveBayesMo
import org.apache.spark.mllib.linalg._
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.{DataFrame, Dataset}
/**
* Params for Naive Bayes Classifiers.
@@ -101,7 +101,7 @@ class NaiveBayes @Since("1.5.0") (
def setModelType(value: String): this.type = set(modelType, value)
setDefault(modelType -> OldNaiveBayes.Multinomial)
- override protected def train(dataset: DataFrame): NaiveBayesModel = {
+ override protected def train(dataset: Dataset[_]): NaiveBayesModel = {
val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset)
val oldModel = OldNaiveBayes.train(oldDataset, $(smoothing), $(modelType))
NaiveBayesModel.fromOld(oldModel, this)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
index c41a611f1c..4de1b877b0 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
@@ -21,22 +21,24 @@ import java.util.UUID
import scala.language.existentials
+import org.apache.hadoop.fs.Path
+import org.json4s.{DefaultFormats, JObject, _}
+import org.json4s.JsonDSL._
+import org.json4s.jackson.JsonMethods._
+
+import org.apache.spark.SparkContext
import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.ml._
import org.apache.spark.ml.attribute._
-import org.apache.spark.ml.param.{Param, ParamMap}
-import org.apache.spark.ml.util.{Identifiable, MetadataUtils}
+import org.apache.spark.ml.param.{Param, ParamMap, ParamPair, Params}
+import org.apache.spark.ml.util._
import org.apache.spark.mllib.linalg.Vector
-import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.{DataFrame, Dataset, Row}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
import org.apache.spark.storage.StorageLevel
-/**
- * Params for [[OneVsRest]].
- */
-private[ml] trait OneVsRestParams extends PredictorParams {
-
+private[ml] trait ClassifierTypeTrait {
// scalastyle:off structural.type
type ClassifierType = Classifier[F, E, M] forSome {
type F
@@ -44,6 +46,12 @@ private[ml] trait OneVsRestParams extends PredictorParams {
type E <: Classifier[F, E, M]
}
// scalastyle:on structural.type
+}
+
+/**
+ * Params for [[OneVsRest]].
+ */
+private[ml] trait OneVsRestParams extends PredictorParams with ClassifierTypeTrait {
/**
* param for the base binary classifier that we reduce multiclass classification into.
@@ -57,6 +65,55 @@ private[ml] trait OneVsRestParams extends PredictorParams {
def getClassifier: ClassifierType = $(classifier)
}
+private[ml] object OneVsRestParams extends ClassifierTypeTrait {
+
+ def validateParams(instance: OneVsRestParams): Unit = {
+ def checkElement(elem: Params, name: String): Unit = elem match {
+ case stage: MLWritable => // good
+ case other =>
+ throw new UnsupportedOperationException("OneVsRest write will fail " +
+ s" because it contains $name which does not implement MLWritable." +
+ s" Non-Writable $name: ${other.uid} of type ${other.getClass}")
+ }
+
+ instance match {
+ case ovrModel: OneVsRestModel => ovrModel.models.foreach(checkElement(_, "model"))
+ case _ => // no need to check OneVsRest here
+ }
+
+ checkElement(instance.getClassifier, "classifier")
+ }
+
+ def saveImpl(
+ path: String,
+ instance: OneVsRestParams,
+ sc: SparkContext,
+ extraMetadata: Option[JObject] = None): Unit = {
+
+ val params = instance.extractParamMap().toSeq
+ val jsonParams = render(params
+ .filter { case ParamPair(p, v) => p.name != "classifier" }
+ .map { case ParamPair(p, v) => p.name -> parse(p.jsonEncode(v)) }
+ .toList)
+
+ DefaultParamsWriter.saveMetadata(instance, path, sc, extraMetadata, Some(jsonParams))
+
+ val classifierPath = new Path(path, "classifier").toString
+ instance.getClassifier.asInstanceOf[MLWritable].save(classifierPath)
+ }
+
+ def loadImpl(
+ path: String,
+ sc: SparkContext,
+ expectedClassName: String): (DefaultParamsReader.Metadata, ClassifierType) = {
+
+ val metadata = DefaultParamsReader.loadMetadata(path, sc, expectedClassName)
+ val classifierPath = new Path(path, "classifier").toString
+ val estimator = DefaultParamsReader.loadParamsInstance[ClassifierType](classifierPath, sc)
+ (metadata, estimator)
+ }
+}
+
/**
* :: Experimental ::
* Model produced by [[OneVsRest]].
@@ -73,18 +130,18 @@ private[ml] trait OneVsRestParams extends PredictorParams {
@Since("1.4.0")
@Experimental
final class OneVsRestModel private[ml] (
- @Since("1.4.0") override val uid: String,
- @Since("1.4.0") labelMetadata: Metadata,
+ @Since("1.4.0") override val uid: String,
+ private[ml] val labelMetadata: Metadata,
@Since("1.4.0") val models: Array[_ <: ClassificationModel[_, _]])
- extends Model[OneVsRestModel] with OneVsRestParams {
+ extends Model[OneVsRestModel] with OneVsRestParams with MLWritable {
@Since("1.4.0")
override def transformSchema(schema: StructType): StructType = {
validateAndTransformSchema(schema, fitting = false, getClassifier.featuresDataType)
}
- @Since("1.4.0")
- override def transform(dataset: DataFrame): DataFrame = {
+ @Since("2.0.0")
+ override def transform(dataset: Dataset[_]): DataFrame = {
// Check schema
transformSchema(dataset.schema, logging = true)
@@ -143,6 +200,56 @@ final class OneVsRestModel private[ml] (
uid, labelMetadata, models.map(_.copy(extra).asInstanceOf[ClassificationModel[_, _]]))
copyValues(copied, extra).setParent(parent)
}
+
+ @Since("2.0.0")
+ override def write: MLWriter = new OneVsRestModel.OneVsRestModelWriter(this)
+}
+
+@Since("2.0.0")
+object OneVsRestModel extends MLReadable[OneVsRestModel] {
+
+ @Since("2.0.0")
+ override def read: MLReader[OneVsRestModel] = new OneVsRestModelReader
+
+ @Since("2.0.0")
+ override def load(path: String): OneVsRestModel = super.load(path)
+
+ /** [[MLWriter]] instance for [[OneVsRestModel]] */
+ private[OneVsRestModel] class OneVsRestModelWriter(instance: OneVsRestModel) extends MLWriter {
+
+ OneVsRestParams.validateParams(instance)
+
+ override protected def saveImpl(path: String): Unit = {
+ val extraJson = ("labelMetadata" -> instance.labelMetadata.json) ~
+ ("numClasses" -> instance.models.length)
+ OneVsRestParams.saveImpl(path, instance, sc, Some(extraJson))
+ instance.models.zipWithIndex.foreach { case (model: MLWritable, idx) =>
+ val modelPath = new Path(path, s"model_$idx").toString
+ model.save(modelPath)
+ }
+ }
+ }
+
+ private class OneVsRestModelReader extends MLReader[OneVsRestModel] {
+
+ /** Checked against metadata when loading model */
+ private val className = classOf[OneVsRestModel].getName
+
+ override def load(path: String): OneVsRestModel = {
+ implicit val format = DefaultFormats
+ val (metadata, classifier) = OneVsRestParams.loadImpl(path, sc, className)
+ val labelMetadata = Metadata.fromJson((metadata.metadata \ "labelMetadata").extract[String])
+ val numClasses = (metadata.metadata \ "numClasses").extract[Int]
+ val models = Range(0, numClasses).toArray.map { idx =>
+ val modelPath = new Path(path, s"model_$idx").toString
+ DefaultParamsReader.loadParamsInstance[ClassificationModel[_, _]](modelPath, sc)
+ }
+ val ovrModel = new OneVsRestModel(metadata.uid, labelMetadata, models)
+ DefaultParamsReader.getAndSetParams(ovrModel, metadata)
+ ovrModel.set("classifier", classifier)
+ ovrModel
+ }
+ }
}
/**
@@ -158,7 +265,7 @@ final class OneVsRestModel private[ml] (
@Experimental
final class OneVsRest @Since("1.4.0") (
@Since("1.4.0") override val uid: String)
- extends Estimator[OneVsRestModel] with OneVsRestParams {
+ extends Estimator[OneVsRestModel] with OneVsRestParams with MLWritable {
@Since("1.4.0")
def this() = this(Identifiable.randomUID("oneVsRest"))
@@ -186,12 +293,14 @@ final class OneVsRest @Since("1.4.0") (
validateAndTransformSchema(schema, fitting = true, getClassifier.featuresDataType)
}
- @Since("1.4.0")
- override def fit(dataset: DataFrame): OneVsRestModel = {
+ @Since("2.0.0")
+ override def fit(dataset: Dataset[_]): OneVsRestModel = {
+ transformSchema(dataset.schema)
+
// determine number of classes either from metadata if provided, or via computation.
val labelSchema = dataset.schema($(labelCol))
val computeNumClasses: () => Int = () => {
- val Row(maxLabelIndex: Double) = dataset.agg(max($(labelCol))).head()
+ val Row(maxLabelIndex: Double) = dataset.agg(max(col($(labelCol)).cast(DoubleType))).head()
// classes are assumed to be numbered from 0,...,maxLabelIndex
maxLabelIndex.toInt + 1
}
@@ -243,4 +352,40 @@ final class OneVsRest @Since("1.4.0") (
}
copied
}
+
+ @Since("2.0.0")
+ override def write: MLWriter = new OneVsRest.OneVsRestWriter(this)
+}
+
+@Since("2.0.0")
+object OneVsRest extends MLReadable[OneVsRest] {
+
+ @Since("2.0.0")
+ override def read: MLReader[OneVsRest] = new OneVsRestReader
+
+ @Since("2.0.0")
+ override def load(path: String): OneVsRest = super.load(path)
+
+ /** [[MLWriter]] instance for [[OneVsRest]] */
+ private[OneVsRest] class OneVsRestWriter(instance: OneVsRest) extends MLWriter {
+
+ OneVsRestParams.validateParams(instance)
+
+ override protected def saveImpl(path: String): Unit = {
+ OneVsRestParams.saveImpl(path, instance, sc)
+ }
+ }
+
+ private class OneVsRestReader extends MLReader[OneVsRest] {
+
+ /** Checked against metadata when loading model */
+ private val className = classOf[OneVsRest].getName
+
+ override def load(path: String): OneVsRest = {
+ val (metadata, classifier) = OneVsRestParams.loadImpl(path, sc, className)
+ val ovr = new OneVsRest(metadata.uid)
+ DefaultParamsReader.getAndSetParams(ovr, metadata)
+ ovr.setClassifier(classifier)
+ }
+ }
}
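
OneVsRest and OneVsRestModel become writable/readable above: the base classifier is saved under a "classifier" subpath, each per-class model under model_<idx>, and validateParams rejects a base classifier that is not itself MLWritable. A save/load sketch (the classifier choice, path, and DataFrame name are assumptions):

    import org.apache.spark.ml.classification.{LogisticRegression, OneVsRest, OneVsRestModel}

    // assumes `training` has "label" and "features" columns
    val ovr = new OneVsRest().setClassifier(new LogisticRegression().setMaxIter(10))
    val ovrModel = ovr.fit(training)

    // the base classifier must itself be MLWritable for this to succeed
    ovrModel.write.overwrite().save("/tmp/spark-ovr-model")
    val restoredOvr = OneVsRestModel.load("/tmp/spark-ovr-model")
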
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala
index 865614aa5c..d00fee12b0 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala
@@ -21,7 +21,7 @@ import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.util.SchemaUtils
import org.apache.spark.mllib.linalg.{DenseVector, Vector, Vectors, VectorUDT}
-import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{DataType, StructType}
@@ -95,7 +95,7 @@ abstract class ProbabilisticClassificationModel[
* @param dataset input dataset
* @return transformed dataset
*/
- override def transform(dataset: DataFrame): DataFrame = {
+ override def transform(dataset: Dataset[_]): DataFrame = {
transformSchema(dataset.schema, logging = true)
if (isDefined(thresholds)) {
require($(thresholds).length == numClasses, this.getClass.getSimpleName +
@@ -145,7 +145,7 @@ abstract class ProbabilisticClassificationModel[
this.logWarning(s"$uid: ProbabilisticClassificationModel.transform() was called as NOOP" +
" since no output columns were set.")
}
- outputData
+ outputData.toDF
}
/**
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
index 5da04d341d..dfa711b243 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
@@ -17,17 +17,21 @@
package org.apache.spark.ml.classification
+import org.json4s.{DefaultFormats, JObject}
+import org.json4s.JsonDSL._
+
import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.ml.param.ParamMap
-import org.apache.spark.ml.tree.{DecisionTreeModel, RandomForestParams, TreeClassifierParams, TreeEnsembleModel}
+import org.apache.spark.ml.tree._
import org.apache.spark.ml.tree.impl.RandomForest
-import org.apache.spark.ml.util.{Identifiable, MetadataUtils}
+import org.apache.spark.ml.util._
+import org.apache.spark.ml.util.DefaultParamsReader.Metadata
import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
import org.apache.spark.mllib.tree.model.{RandomForestModel => OldRandomForestModel}
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.functions._
@@ -43,7 +47,7 @@ import org.apache.spark.sql.functions._
final class RandomForestClassifier @Since("1.4.0") (
@Since("1.4.0") override val uid: String)
extends ProbabilisticClassifier[Vector, RandomForestClassifier, RandomForestClassificationModel]
- with RandomForestParams with TreeClassifierParams {
+ with RandomForestClassifierParams with DefaultParamsWritable {
@Since("1.4.0")
def this() = this(Identifiable.randomUID("rfc"))
@@ -94,7 +98,7 @@ final class RandomForestClassifier @Since("1.4.0") (
override def setFeatureSubsetStrategy(value: String): this.type =
super.setFeatureSubsetStrategy(value)
- override protected def train(dataset: DataFrame): RandomForestClassificationModel = {
+ override protected def train(dataset: Dataset[_]): RandomForestClassificationModel = {
val categoricalFeatures: Map[Int, Int] =
MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
val numClasses: Int = MetadataUtils.getNumClasses(dataset.schema($(labelCol))) match {
@@ -120,7 +124,7 @@ final class RandomForestClassifier @Since("1.4.0") (
@Since("1.4.0")
@Experimental
-object RandomForestClassifier {
+object RandomForestClassifier extends DefaultParamsReadable[RandomForestClassifier] {
/** Accessor for supported impurity settings: entropy, gini */
@Since("1.4.0")
final val supportedImpurities: Array[String] = TreeClassifierParams.supportedImpurities
@@ -129,6 +133,9 @@ object RandomForestClassifier {
@Since("1.4.0")
final val supportedFeatureSubsetStrategies: Array[String] =
RandomForestParams.supportedFeatureSubsetStrategies
+
+ @Since("2.0.0")
+ override def load(path: String): RandomForestClassifier = super.load(path)
}
/**
@@ -136,8 +143,9 @@ object RandomForestClassifier {
* [[http://en.wikipedia.org/wiki/Random_forest Random Forest]] model for classification.
* It supports both binary and multiclass labels, as well as both continuous and categorical
* features.
+ *
* @param _trees Decision trees in the ensemble.
- * Warning: These have null parents.
+ * Warning: These have null parents.
*/
@Since("1.4.0")
@Experimental
@@ -147,12 +155,14 @@ final class RandomForestClassificationModel private[ml] (
@Since("1.6.0") override val numFeatures: Int,
@Since("1.5.0") override val numClasses: Int)
extends ProbabilisticClassificationModel[Vector, RandomForestClassificationModel]
- with TreeEnsembleModel with Serializable {
+ with RandomForestClassificationModelParams with TreeEnsembleModel[DecisionTreeClassificationModel]
+ with MLWritable with Serializable {
- require(numTrees > 0, "RandomForestClassificationModel requires at least 1 tree.")
+ require(_trees.nonEmpty, "RandomForestClassificationModel requires at least 1 tree.")
/**
* Construct a random forest classification model, with all trees weighted equally.
+ *
* @param trees Component trees
*/
private[ml] def this(
@@ -162,15 +172,15 @@ final class RandomForestClassificationModel private[ml] (
this(Identifiable.randomUID("rfc"), trees, numFeatures, numClasses)
@Since("1.4.0")
- override def trees: Array[DecisionTreeModel] = _trees.asInstanceOf[Array[DecisionTreeModel]]
+ override def trees: Array[DecisionTreeClassificationModel] = _trees
// Note: We may add support for weights (based on tree performance) later on.
- private lazy val _treeWeights: Array[Double] = Array.fill[Double](numTrees)(1.0)
+ private lazy val _treeWeights: Array[Double] = Array.fill[Double](_trees.length)(1.0)
@Since("1.4.0")
override def treeWeights: Array[Double] = _treeWeights
- override protected def transformImpl(dataset: DataFrame): DataFrame = {
+ override protected def transformImpl(dataset: Dataset[_]): DataFrame = {
val bcastModel = dataset.sqlContext.sparkContext.broadcast(this)
val predictUDF = udf { (features: Any) =>
bcastModel.value.predict(features.asInstanceOf[Vector])
@@ -208,6 +218,15 @@ final class RandomForestClassificationModel private[ml] (
}
}
+ /**
+ * Number of trees in ensemble
+ *
+ * @deprecated Use [[getNumTrees]] instead. This method will be removed in 2.1.0
+ */
+ // TODO: Once this is removed, then this class can inherit from RandomForestClassifierParams
+ @deprecated("Use getNumTrees instead. This method will be removed in 2.1.0.", "2.0.0")
+ val numTrees: Int = trees.length
+
@Since("1.4.0")
override def copy(extra: ParamMap): RandomForestClassificationModel = {
copyValues(new RandomForestClassificationModel(uid, _trees, numFeatures, numClasses), extra)
@@ -216,36 +235,89 @@ final class RandomForestClassificationModel private[ml] (
@Since("1.4.0")
override def toString: String = {
- s"RandomForestClassificationModel (uid=$uid) with $numTrees trees"
+ s"RandomForestClassificationModel (uid=$uid) with $getNumTrees trees"
}
/**
* Estimate of the importance of each feature.
*
- * This generalizes the idea of "Gini" importance to other losses,
- * following the explanation of Gini importance from "Random Forests" documentation
- * by Leo Breiman and Adele Cutler, and following the implementation from scikit-learn.
+ * Each feature's importance is the average of its importance across all trees in the ensemble.
+ * The importance vector is normalized to sum to 1. This method is suggested by Hastie et al.
+ * (Hastie, Tibshirani, Friedman. "The Elements of Statistical Learning, 2nd Edition." 2001.)
+ * and follows the implementation from scikit-learn.
*
- * This feature importance is calculated as follows:
- * - Average over trees:
- * - importance(feature j) = sum (over nodes which split on feature j) of the gain,
- * where gain is scaled by the number of instances passing through node
- * - Normalize importances for tree to sum to 1.
- * - Normalize feature importance vector to sum to 1.
+ * @see [[DecisionTreeClassificationModel.featureImportances]]
*/
@Since("1.5.0")
- lazy val featureImportances: Vector = RandomForest.featureImportances(trees, numFeatures)
+ lazy val featureImportances: Vector = TreeEnsembleModel.featureImportances(trees, numFeatures)
/** (private[ml]) Convert to a model in the old API */
private[ml] def toOld: OldRandomForestModel = {
new OldRandomForestModel(OldAlgo.Classification, _trees.map(_.toOld))
}
+
+ @Since("2.0.0")
+ override def write: MLWriter =
+ new RandomForestClassificationModel.RandomForestClassificationModelWriter(this)
}
-private[ml] object RandomForestClassificationModel {
+@Since("2.0.0")
+object RandomForestClassificationModel extends MLReadable[RandomForestClassificationModel] {
+
+ @Since("2.0.0")
+ override def read: MLReader[RandomForestClassificationModel] =
+ new RandomForestClassificationModelReader
+
+ @Since("2.0.0")
+ override def load(path: String): RandomForestClassificationModel = super.load(path)
+
+ private[RandomForestClassificationModel]
+ class RandomForestClassificationModelWriter(instance: RandomForestClassificationModel)
+ extends MLWriter {
+
+ override protected def saveImpl(path: String): Unit = {
+ // Note: numTrees is not currently used, but could be nice to store for fast querying.
+ val extraMetadata: JObject = Map(
+ "numFeatures" -> instance.numFeatures,
+ "numClasses" -> instance.numClasses,
+ "numTrees" -> instance.getNumTrees)
+ EnsembleModelReadWrite.saveImpl(instance, path, sqlContext, extraMetadata)
+ }
+ }
+
+ private class RandomForestClassificationModelReader
+ extends MLReader[RandomForestClassificationModel] {
+
+ /** Checked against metadata when loading model */
+ private val className = classOf[RandomForestClassificationModel].getName
+ private val treeClassName = classOf[DecisionTreeClassificationModel].getName
+
+ override def load(path: String): RandomForestClassificationModel = {
+ implicit val format = DefaultFormats
+ val (metadata: Metadata, treesData: Array[(Metadata, Node)], _) =
+ EnsembleModelReadWrite.loadImpl(path, sqlContext, className, treeClassName)
+ val numFeatures = (metadata.metadata \ "numFeatures").extract[Int]
+ val numClasses = (metadata.metadata \ "numClasses").extract[Int]
+ val numTrees = (metadata.metadata \ "numTrees").extract[Int]
+
+ val trees: Array[DecisionTreeClassificationModel] = treesData.map {
+ case (treeMetadata, root) =>
+ val tree =
+ new DecisionTreeClassificationModel(treeMetadata.uid, root, numFeatures, numClasses)
+ DefaultParamsReader.getAndSetParams(tree, treeMetadata)
+ tree
+ }
+ require(numTrees == trees.length, s"RandomForestClassificationModel.load expected $numTrees" +
+ s" trees based on metadata but found ${trees.length} trees.")
+
+ val model = new RandomForestClassificationModel(metadata.uid, trees, numFeatures, numClasses)
+ DefaultParamsReader.getAndSetParams(model, metadata)
+ model
+ }
+ }
- /** (private[ml]) Convert a model from the old API */
- def fromOld(
+ /** Convert a model from the old API */
+ private[ml] def fromOld(
oldModel: OldRandomForestModel,
parent: RandomForestClassifier,
categoricalFeatures: Map[Int, Int],
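
RandomForestClassificationModel above deprecates its numTrees val in favor of getNumTrees and gains persistence through EnsembleModelReadWrite. A usage sketch (the training DataFrame is an assumption):

    import org.apache.spark.ml.classification.RandomForestClassifier

    // assumes `training` has "label" and "features" columns
    val rf = new RandomForestClassifier().setNumTrees(20).setMaxDepth(5)
    val rfModel = rf.fit(training)

    println(rfModel.getNumTrees)          // preferred over the deprecated numTrees
    println(rfModel.featureImportances)   // averaged over trees, normalized to sum to 1
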
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala
index f014a1d572..6cc9117da3 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala
@@ -17,15 +17,17 @@
package org.apache.spark.ml.clustering
+import org.apache.hadoop.fs.Path
+
import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.ml.{Estimator, Model}
-import org.apache.spark.ml.param.{IntParam, Param, ParamMap, Params}
+import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
-import org.apache.spark.ml.util.{Identifiable, SchemaUtils}
+import org.apache.spark.ml.util._
import org.apache.spark.mllib.clustering.
{BisectingKMeans => MLlibBisectingKMeans, BisectingKMeansModel => MLlibBisectingKMeansModel}
import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
-import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.{DataFrame, Dataset, Row}
import org.apache.spark.sql.functions.{col, udf}
import org.apache.spark.sql.types.{IntegerType, StructType}
@@ -49,7 +51,7 @@ private[clustering] trait BisectingKMeansParams extends Params
/** @group expertParam */
@Since("2.0.0")
- final val minDivisibleClusterSize = new Param[Double](
+ final val minDivisibleClusterSize = new DoubleParam(
this,
"minDivisibleClusterSize",
"the minimum number of points (if >= 1.0) or the minimum proportion",
@@ -81,7 +83,7 @@ private[clustering] trait BisectingKMeansParams extends Params
class BisectingKMeansModel private[ml] (
@Since("2.0.0") override val uid: String,
private val parentModel: MLlibBisectingKMeansModel
- ) extends Model[BisectingKMeansModel] with BisectingKMeansParams {
+ ) extends Model[BisectingKMeansModel] with BisectingKMeansParams with MLWritable {
@Since("2.0.0")
override def copy(extra: ParamMap): BisectingKMeansModel = {
@@ -90,7 +92,7 @@ class BisectingKMeansModel private[ml] (
}
@Since("2.0.0")
- override def transform(dataset: DataFrame): DataFrame = {
+ override def transform(dataset: Dataset[_]): DataFrame = {
val predictUDF = udf((vector: Vector) => predict(vector))
dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
}
@@ -110,11 +112,49 @@ class BisectingKMeansModel private[ml] (
* centers.
*/
@Since("2.0.0")
- def computeCost(dataset: DataFrame): Double = {
+ def computeCost(dataset: Dataset[_]): Double = {
SchemaUtils.checkColumnType(dataset.schema, $(featuresCol), new VectorUDT)
val data = dataset.select(col($(featuresCol))).rdd.map { case Row(point: Vector) => point }
parentModel.computeCost(data)
}
+
+ @Since("2.0.0")
+ override def write: MLWriter = new BisectingKMeansModel.BisectingKMeansModelWriter(this)
+}
+
+object BisectingKMeansModel extends MLReadable[BisectingKMeansModel] {
+ @Since("2.0.0")
+ override def read: MLReader[BisectingKMeansModel] = new BisectingKMeansModelReader
+
+ @Since("2.0.0")
+ override def load(path: String): BisectingKMeansModel = super.load(path)
+
+ /** [[MLWriter]] instance for [[BisectingKMeansModel]] */
+ private[BisectingKMeansModel]
+ class BisectingKMeansModelWriter(instance: BisectingKMeansModel) extends MLWriter {
+
+ override protected def saveImpl(path: String): Unit = {
+ // Save metadata and Params
+ DefaultParamsWriter.saveMetadata(instance, path, sc)
+ val dataPath = new Path(path, "data").toString
+ instance.parentModel.save(sc, dataPath)
+ }
+ }
+
+ private class BisectingKMeansModelReader extends MLReader[BisectingKMeansModel] {
+
+ /** Checked against metadata when loading model */
+ private val className = classOf[BisectingKMeansModel].getName
+
+ override def load(path: String): BisectingKMeansModel = {
+ val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+ val dataPath = new Path(path, "data").toString
+ val mllibModel = MLlibBisectingKMeansModel.load(sc, dataPath)
+ val model = new BisectingKMeansModel(metadata.uid, mllibModel)
+ DefaultParamsReader.getAndSetParams(model, metadata)
+ model
+ }
+ }
}
/**
@@ -137,7 +177,7 @@ class BisectingKMeansModel private[ml] (
@Experimental
class BisectingKMeans @Since("2.0.0") (
@Since("2.0.0") override val uid: String)
- extends Estimator[BisectingKMeansModel] with BisectingKMeansParams {
+ extends Estimator[BisectingKMeansModel] with BisectingKMeansParams with DefaultParamsWritable {
setDefault(
k -> 4,
@@ -148,7 +188,7 @@ class BisectingKMeans @Since("2.0.0") (
override def copy(extra: ParamMap): BisectingKMeans = defaultCopy(extra)
@Since("2.0.0")
- def this() = this(Identifiable.randomUID("bisecting k-means"))
+ def this() = this(Identifiable.randomUID("bisecting-kmeans"))
/** @group setParam */
@Since("2.0.0")
@@ -175,7 +215,7 @@ class BisectingKMeans @Since("2.0.0") (
def setMinDivisibleClusterSize(value: Double): this.type = set(minDivisibleClusterSize, value)
@Since("2.0.0")
- override def fit(dataset: DataFrame): BisectingKMeansModel = {
+ override def fit(dataset: Dataset[_]): BisectingKMeansModel = {
val rdd = dataset.select(col($(featuresCol))).rdd.map { case Row(point: Vector) => point }
val bkm = new MLlibBisectingKMeans()
@@ -194,3 +234,10 @@ class BisectingKMeans @Since("2.0.0") (
}
}
+
+@Since("2.0.0")
+object BisectingKMeans extends DefaultParamsReadable[BisectingKMeans] {
+
+ @Since("2.0.0")
+ override def load(path: String): BisectingKMeans = super.load(path)
+}
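
The hunks above make both the BisectingKMeans estimator and its model persistable (DefaultParamsWritable / MLWritable plus companion readers). A minimal save/load sketch; the sqlContext, toy data, and /tmp path are illustrative assumptions, not part of the patch:

import org.apache.spark.ml.clustering.{BisectingKMeans, BisectingKMeansModel}
import org.apache.spark.mllib.linalg.Vectors

// Assumed to be in scope: a SQLContext named sqlContext.
val training = sqlContext.createDataFrame(Seq(
  Tuple1(Vectors.dense(0.0, 0.0)),
  Tuple1(Vectors.dense(1.0, 1.0)),
  Tuple1(Vectors.dense(9.0, 8.0)),
  Tuple1(Vectors.dense(8.0, 9.0))
)).toDF("features")

val model = new BisectingKMeans().setK(2).fit(training)

// New with this patch: write/save on the model, load on the companion object.
model.write.overwrite().save("/tmp/bisecting-kmeans-model")
val restored = BisectingKMeansModel.load("/tmp/bisecting-kmeans-model")
assert(restored.computeCost(training) == model.computeCost(training))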
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala
new file mode 100644
index 0000000000..ead8ad7806
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala
@@ -0,0 +1,311 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.clustering
+
+import org.apache.hadoop.fs.Path
+
+import org.apache.spark.annotation.{Experimental, Since}
+import org.apache.spark.ml.{Estimator, Model}
+import org.apache.spark.ml.param.{IntParam, ParamMap, Params}
+import org.apache.spark.ml.param.shared._
+import org.apache.spark.ml.util._
+import org.apache.spark.mllib.clustering.{GaussianMixture => MLlibGM, GaussianMixtureModel => MLlibGMModel}
+import org.apache.spark.mllib.linalg._
+import org.apache.spark.mllib.stat.distribution.MultivariateGaussian
+import org.apache.spark.sql.{DataFrame, Dataset, Row}
+import org.apache.spark.sql.functions.{col, udf}
+import org.apache.spark.sql.types.{IntegerType, StructType}
+
+
+/**
+ * Common params for GaussianMixture and GaussianMixtureModel
+ */
+private[clustering] trait GaussianMixtureParams extends Params with HasMaxIter with HasFeaturesCol
+ with HasSeed with HasPredictionCol with HasProbabilityCol with HasTol {
+
+ /**
+   * Number of clusters to create (k). Must be > 1. Default: 2.
+ * @group param
+ */
+ @Since("2.0.0")
+ final val k = new IntParam(this, "k", "number of clusters to create", (x: Int) => x > 1)
+
+ /** @group getParam */
+ @Since("2.0.0")
+ def getK: Int = $(k)
+
+ /**
+ * Validates and transforms the input schema.
+ * @param schema input schema
+ * @return output schema
+ */
+ protected def validateAndTransformSchema(schema: StructType): StructType = {
+ SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT)
+ SchemaUtils.appendColumn(schema, $(predictionCol), IntegerType)
+ SchemaUtils.appendColumn(schema, $(probabilityCol), new VectorUDT)
+ }
+}
+
+/**
+ * :: Experimental ::
+ * Model fitted by GaussianMixture.
+ * @param parentModel a model trained by spark.mllib.clustering.GaussianMixture.
+ */
+@Since("2.0.0")
+@Experimental
+class GaussianMixtureModel private[ml] (
+ @Since("2.0.0") override val uid: String,
+ private val parentModel: MLlibGMModel)
+ extends Model[GaussianMixtureModel] with GaussianMixtureParams with MLWritable {
+
+ @Since("2.0.0")
+ override def copy(extra: ParamMap): GaussianMixtureModel = {
+ val copied = new GaussianMixtureModel(uid, parentModel)
+ copyValues(copied, extra).setParent(this.parent)
+ }
+
+ @Since("2.0.0")
+ override def transform(dataset: Dataset[_]): DataFrame = {
+ val predUDF = udf((vector: Vector) => predict(vector))
+ val probUDF = udf((vector: Vector) => predictProbability(vector))
+ dataset.withColumn($(predictionCol), predUDF(col($(featuresCol))))
+ .withColumn($(probabilityCol), probUDF(col($(featuresCol))))
+ }
+
+ @Since("2.0.0")
+ override def transformSchema(schema: StructType): StructType = {
+ validateAndTransformSchema(schema)
+ }
+
+ private[clustering] def predict(features: Vector): Int = parentModel.predict(features)
+
+ private[clustering] def predictProbability(features: Vector): Vector = {
+ Vectors.dense(parentModel.predictSoft(features))
+ }
+
+ @Since("2.0.0")
+ def weights: Array[Double] = parentModel.weights
+
+ @Since("2.0.0")
+ def gaussians: Array[MultivariateGaussian] = parentModel.gaussians
+
+ @Since("2.0.0")
+ override def write: MLWriter = new GaussianMixtureModel.GaussianMixtureModelWriter(this)
+
+ private var trainingSummary: Option[GaussianMixtureSummary] = None
+
+ private[clustering] def setSummary(summary: GaussianMixtureSummary): this.type = {
+ this.trainingSummary = Some(summary)
+ this
+ }
+
+ /**
+   * Returns true if a training summary exists for this model.
+ */
+ @Since("2.0.0")
+ def hasSummary: Boolean = trainingSummary.nonEmpty
+
+ /**
+ * Gets summary of model on training set. An exception is
+ * thrown if `trainingSummary == None`.
+ */
+ @Since("2.0.0")
+ def summary: GaussianMixtureSummary = trainingSummary.getOrElse {
+ throw new RuntimeException(
+ s"No training summary available for the ${this.getClass.getSimpleName}")
+ }
+}
+
+@Since("2.0.0")
+object GaussianMixtureModel extends MLReadable[GaussianMixtureModel] {
+
+ @Since("2.0.0")
+ override def read: MLReader[GaussianMixtureModel] = new GaussianMixtureModelReader
+
+ @Since("2.0.0")
+ override def load(path: String): GaussianMixtureModel = super.load(path)
+
+ /** [[MLWriter]] instance for [[GaussianMixtureModel]] */
+ private[GaussianMixtureModel] class GaussianMixtureModelWriter(
+ instance: GaussianMixtureModel) extends MLWriter {
+
+ private case class Data(weights: Array[Double], mus: Array[Vector], sigmas: Array[Matrix])
+
+ override protected def saveImpl(path: String): Unit = {
+ // Save metadata and Params
+ DefaultParamsWriter.saveMetadata(instance, path, sc)
+ // Save model data: weights and gaussians
+ val weights = instance.weights
+ val gaussians = instance.gaussians
+ val mus = gaussians.map(_.mu)
+ val sigmas = gaussians.map(_.sigma)
+ val data = Data(weights, mus, sigmas)
+ val dataPath = new Path(path, "data").toString
+ sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
+ }
+ }
+
+ private class GaussianMixtureModelReader extends MLReader[GaussianMixtureModel] {
+
+ /** Checked against metadata when loading model */
+ private val className = classOf[GaussianMixtureModel].getName
+
+ override def load(path: String): GaussianMixtureModel = {
+ val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+
+ val dataPath = new Path(path, "data").toString
+ val row = sqlContext.read.parquet(dataPath).select("weights", "mus", "sigmas").head()
+ val weights = row.getSeq[Double](0).toArray
+ val mus = row.getSeq[Vector](1).toArray
+ val sigmas = row.getSeq[Matrix](2).toArray
+ require(mus.length == sigmas.length, "Length of Mu and Sigma array must match")
+ require(mus.length == weights.length, "Length of weight and Gaussian array must match")
+
+ val gaussians = (mus zip sigmas).map {
+ case (mu, sigma) =>
+ new MultivariateGaussian(mu, sigma)
+ }
+ val model = new GaussianMixtureModel(metadata.uid, new MLlibGMModel(weights, gaussians))
+
+ DefaultParamsReader.getAndSetParams(model, metadata)
+ model
+ }
+ }
+}
+
+/**
+ * :: Experimental ::
+ * GaussianMixture clustering.
+ */
+@Since("2.0.0")
+@Experimental
+class GaussianMixture @Since("2.0.0") (
+ @Since("2.0.0") override val uid: String)
+ extends Estimator[GaussianMixtureModel] with GaussianMixtureParams with DefaultParamsWritable {
+
+ setDefault(
+ k -> 2,
+ maxIter -> 100,
+ tol -> 0.01)
+
+ @Since("2.0.0")
+ override def copy(extra: ParamMap): GaussianMixture = defaultCopy(extra)
+
+ @Since("2.0.0")
+ def this() = this(Identifiable.randomUID("GaussianMixture"))
+
+ /** @group setParam */
+ @Since("2.0.0")
+ def setFeaturesCol(value: String): this.type = set(featuresCol, value)
+
+ /** @group setParam */
+ @Since("2.0.0")
+ def setPredictionCol(value: String): this.type = set(predictionCol, value)
+
+ /** @group setParam */
+ @Since("2.0.0")
+ def setProbabilityCol(value: String): this.type = set(probabilityCol, value)
+
+ /** @group setParam */
+ @Since("2.0.0")
+ def setK(value: Int): this.type = set(k, value)
+
+ /** @group setParam */
+ @Since("2.0.0")
+ def setMaxIter(value: Int): this.type = set(maxIter, value)
+
+ /** @group setParam */
+ @Since("2.0.0")
+ def setTol(value: Double): this.type = set(tol, value)
+
+ /** @group setParam */
+ @Since("2.0.0")
+ def setSeed(value: Long): this.type = set(seed, value)
+
+ @Since("2.0.0")
+ override def fit(dataset: Dataset[_]): GaussianMixtureModel = {
+ val rdd = dataset.select(col($(featuresCol))).rdd.map { case Row(point: Vector) => point }
+
+ val algo = new MLlibGM()
+ .setK($(k))
+ .setMaxIterations($(maxIter))
+ .setSeed($(seed))
+ .setConvergenceTol($(tol))
+ val parentModel = algo.run(rdd)
+ val model = copyValues(new GaussianMixtureModel(uid, parentModel).setParent(this))
+ val summary = new GaussianMixtureSummary(model.transform(dataset),
+ $(predictionCol), $(probabilityCol), $(featuresCol), $(k))
+ model.setSummary(summary)
+ }
+
+ @Since("2.0.0")
+ override def transformSchema(schema: StructType): StructType = {
+ validateAndTransformSchema(schema)
+ }
+}
+
+@Since("2.0.0")
+object GaussianMixture extends DefaultParamsReadable[GaussianMixture] {
+
+ @Since("2.0.0")
+ override def load(path: String): GaussianMixture = super.load(path)
+}
+
+/**
+ * :: Experimental ::
+ * Summary of GaussianMixture.
+ *
+ * @param predictions [[DataFrame]] produced by [[GaussianMixtureModel.transform()]]
+ * @param predictionCol Name for column of predicted clusters in `predictions`
+ * @param probabilityCol Name for column of predicted probability of each cluster in `predictions`
+ * @param featuresCol Name for column of features in `predictions`
+ * @param k Number of clusters
+ */
+@Since("2.0.0")
+@Experimental
+class GaussianMixtureSummary private[clustering] (
+ @Since("2.0.0") @transient val predictions: DataFrame,
+ @Since("2.0.0") val predictionCol: String,
+ @Since("2.0.0") val probabilityCol: String,
+ @Since("2.0.0") val featuresCol: String,
+ @Since("2.0.0") val k: Int) extends Serializable {
+
+ /**
+ * Cluster centers of the transformed data.
+ */
+ @Since("2.0.0")
+ @transient lazy val cluster: DataFrame = predictions.select(predictionCol)
+
+ /**
+ * Probability of each cluster.
+ */
+ @Since("2.0.0")
+ @transient lazy val probability: DataFrame = predictions.select(probabilityCol)
+
+ /**
+ * Size of (number of data points in) each cluster.
+ */
+ @Since("2.0.0")
+ lazy val clusterSizes: Array[Long] = {
+ val sizes = Array.fill[Long](k)(0)
+ cluster.groupBy(predictionCol).count().select(predictionCol, "count").collect().foreach {
+ case Row(cluster: Int, count: Long) => sizes(cluster) = count
+ }
+ sizes
+ }
+}
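
GaussianMixture.scala is an entirely new file, so a short usage sketch of the API it introduces may help review; the sqlContext and toy data below are assumptions, while the estimator, model accessors, and summary come from the file above:

import org.apache.spark.ml.clustering.GaussianMixture
import org.apache.spark.mllib.linalg.Vectors

val df = sqlContext.createDataFrame(Seq(
  Tuple1(Vectors.dense(-0.1)), Tuple1(Vectors.dense(0.1)),
  Tuple1(Vectors.dense(9.9)), Tuple1(Vectors.dense(10.1))
)).toDF("features")

val model = new GaussianMixture().setK(2).setMaxIter(50).setSeed(1L).fit(df)

// Mixture parameters exposed by the model.
model.weights.foreach(println)
model.gaussians.foreach(g => println(s"mu=${g.mu} sigma=${g.sigma}"))

// transform() appends a hard assignment ("prediction") and a per-component
// probability vector ("probability").
model.transform(df).select("features", "prediction", "probability").show()

// Training summary captured by fit(); clusterSizes has length k.
println(model.summary.clusterSizes.mkString("[", ", ", "]"))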
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
index 38428826a8..b324196842 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
@@ -27,7 +27,7 @@ import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.util._
import org.apache.spark.mllib.clustering.{KMeans => MLlibKMeans, KMeansModel => MLlibKMeansModel}
import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
-import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.{DataFrame, Dataset, Row}
import org.apache.spark.sql.functions.{col, udf}
import org.apache.spark.sql.types.{IntegerType, StructType}
@@ -105,8 +105,8 @@ class KMeansModel private[ml] (
copyValues(copied, extra)
}
- @Since("1.5.0")
- override def transform(dataset: DataFrame): DataFrame = {
+ @Since("2.0.0")
+ override def transform(dataset: Dataset[_]): DataFrame = {
val predictUDF = udf((vector: Vector) => predict(vector))
dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
}
@@ -126,8 +126,8 @@ class KMeansModel private[ml] (
* model on the given data.
*/
// TODO: Replace the temp fix when we have proper evaluators defined for clustering.
- @Since("1.6.0")
- def computeCost(dataset: DataFrame): Double = {
+ @Since("2.0.0")
+ def computeCost(dataset: Dataset[_]): Double = {
SchemaUtils.checkColumnType(dataset.schema, $(featuresCol), new VectorUDT)
val data = dataset.select(col($(featuresCol))).rdd.map { case Row(point: Vector) => point }
parentModel.computeCost(data)
@@ -144,6 +144,12 @@ class KMeansModel private[ml] (
}
/**
+ * Returns true if a training summary exists for this model.
+ */
+ @Since("2.0.0")
+ def hasSummary: Boolean = trainingSummary.nonEmpty
+
+ /**
* Gets summary of model on training set. An exception is
* thrown if `trainingSummary == None`.
*/
@@ -254,8 +260,8 @@ class KMeans @Since("1.5.0") (
@Since("1.5.0")
def setSeed(value: Long): this.type = set(seed, value)
- @Since("1.5.0")
- override def fit(dataset: DataFrame): KMeansModel = {
+ @Since("2.0.0")
+ override def fit(dataset: Dataset[_]): KMeansModel = {
val rdd = dataset.select(col($(featuresCol))).rdd.map { case Row(point: Vector) => point }
val algo = new MLlibKMeans()
@@ -267,7 +273,8 @@ class KMeans @Since("1.5.0") (
.setEpsilon($(tol))
val parentModel = algo.run(rdd)
val model = copyValues(new KMeansModel(uid, parentModel).setParent(this))
- val summary = new KMeansSummary(model.transform(dataset), $(predictionCol), $(featuresCol))
+ val summary = new KMeansSummary(
+ model.transform(dataset), $(predictionCol), $(featuresCol), $(k))
model.setSummary(summary)
}
@@ -284,10 +291,22 @@ object KMeans extends DefaultParamsReadable[KMeans] {
override def load(path: String): KMeans = super.load(path)
}
+/**
+ * :: Experimental ::
+ * Summary of KMeans.
+ *
+ * @param predictions [[DataFrame]] produced by [[KMeansModel.transform()]]
+ * @param predictionCol Name for column of predicted clusters in `predictions`
+ * @param featuresCol Name for column of features in `predictions`
+ * @param k Number of clusters
+ */
+@Since("2.0.0")
+@Experimental
class KMeansSummary private[clustering] (
@Since("2.0.0") @transient val predictions: DataFrame,
@Since("2.0.0") val predictionCol: String,
- @Since("2.0.0") val featuresCol: String) extends Serializable {
+ @Since("2.0.0") val featuresCol: String,
+ @Since("2.0.0") val k: Int) extends Serializable {
/**
* Cluster centers of the transformed data.
@@ -296,10 +315,15 @@ class KMeansSummary private[clustering] (
@transient lazy val cluster: DataFrame = predictions.select(predictionCol)
/**
- * Size of each cluster.
+ * Size of (number of data points in) each cluster.
*/
@Since("2.0.0")
- lazy val size: Array[Int] = cluster.rdd.map {
- case Row(clusterIdx: Int) => (clusterIdx, 1)
- }.reduceByKey(_ + _).collect().sortBy(_._1).map(_._2)
+ lazy val clusterSizes: Array[Long] = {
+ val sizes = Array.fill[Long](k)(0)
+ cluster.groupBy(predictionCol).count().select(predictionCol, "count").collect().foreach {
+ case Row(cluster: Int, count: Long) => sizes(cluster) = count
+ }
+ sizes
+ }
+
}
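
The KMeans hunks add hasSummary and replace KMeansSummary.size (an Array[Int] built with reduceByKey) with clusterSizes, an Array[Long] of length k indexed by cluster, so clusters that received no points report 0. A minimal sketch, with sqlContext and data assumed:

import org.apache.spark.ml.clustering.KMeans
import org.apache.spark.mllib.linalg.Vectors

val df = sqlContext.createDataFrame(Seq(
  Tuple1(Vectors.dense(0.0, 0.0)), Tuple1(Vectors.dense(0.1, 0.1)),
  Tuple1(Vectors.dense(9.0, 9.0)), Tuple1(Vectors.dense(9.1, 9.1))
)).toDF("features")

val model = new KMeans().setK(2).setSeed(1L).fit(df)

if (model.hasSummary) {
  // One entry per cluster index 0 until k, zero for empty clusters.
  println(model.summary.clusterSizes.mkString("[", ", ", "]"))
}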
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
index fe6a37fd6d..c57ceba4a9 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
@@ -17,21 +17,22 @@
package org.apache.spark.ml.clustering
-import org.apache.hadoop.fs.Path
+import org.apache.hadoop.fs.{FileSystem, Path}
-import org.apache.spark.annotation.{Experimental, Since}
+import org.apache.spark.annotation.{DeveloperApi, Experimental, Since}
import org.apache.spark.internal.Logging
import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared.{HasCheckpointInterval, HasFeaturesCol, HasMaxIter, HasSeed}
import org.apache.spark.ml.util._
import org.apache.spark.mllib.clustering.{DistributedLDAModel => OldDistributedLDAModel,
- EMLDAOptimizer => OldEMLDAOptimizer, LDA => OldLDA, LDAModel => OldLDAModel,
- LDAOptimizer => OldLDAOptimizer, LocalLDAModel => OldLocalLDAModel,
- OnlineLDAOptimizer => OldOnlineLDAOptimizer}
+ EMLDAOptimizer => OldEMLDAOptimizer, LDA => OldLDA, LDAModel => OldLDAModel,
+ LDAOptimizer => OldLDAOptimizer, LocalLDAModel => OldLocalLDAModel,
+ OnlineLDAOptimizer => OldOnlineLDAOptimizer}
+import org.apache.spark.mllib.impl.PeriodicCheckpointer
import org.apache.spark.mllib.linalg.{Matrix, Vector, Vectors, VectorUDT}
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{DataFrame, Row, SQLContext}
+import org.apache.spark.sql.{DataFrame, Dataset, Row, SQLContext}
import org.apache.spark.sql.functions.{col, monotonicallyIncreasingId, udf}
import org.apache.spark.sql.types.StructType
@@ -41,6 +42,7 @@ private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasM
/**
* Param for the number of topics (clusters) to infer. Must be > 1. Default: 10.
+ *
* @group param
*/
@Since("1.6.0")
@@ -173,10 +175,11 @@ private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasM
* This uses a variational approximation following Hoffman et al. (2010), where the approximate
* distribution is called "gamma." Technically, this method returns this approximation "gamma"
* for each document.
+ *
* @group param
*/
@Since("1.6.0")
- final val topicDistributionCol = new Param[String](this, "topicDistribution", "Output column" +
+ final val topicDistributionCol = new Param[String](this, "topicDistributionCol", "Output column" +
" with estimates of the topic mixture distribution for each document (often called \"theta\"" +
" in the literature). Returns a vector of zeros for an empty document.")
@@ -187,15 +190,19 @@ private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasM
def getTopicDistributionCol: String = $(topicDistributionCol)
/**
+ * For Online optimizer only: [[optimizer]] = "online".
+ *
* A (positive) learning parameter that downweights early iterations. Larger values make early
* iterations count less.
* This is called "tau0" in the Online LDA paper (Hoffman et al., 2010)
* Default: 1024, following Hoffman et al.
+ *
* @group expertParam
*/
@Since("1.6.0")
- final val learningOffset = new DoubleParam(this, "learningOffset", "A (positive) learning" +
- " parameter that downweights early iterations. Larger values make early iterations count less.",
+ final val learningOffset = new DoubleParam(this, "learningOffset", "(For online optimizer)" +
+ " A (positive) learning parameter that downweights early iterations. Larger values make early" +
+ " iterations count less.",
ParamValidators.gt(0))
/** @group expertGetParam */
@@ -203,22 +210,27 @@ private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasM
def getLearningOffset: Double = $(learningOffset)
/**
+ * For Online optimizer only: [[optimizer]] = "online".
+ *
* Learning rate, set as an exponential decay rate.
* This should be between (0.5, 1.0] to guarantee asymptotic convergence.
* This is called "kappa" in the Online LDA paper (Hoffman et al., 2010).
* Default: 0.51, based on Hoffman et al.
+ *
* @group expertParam
*/
@Since("1.6.0")
- final val learningDecay = new DoubleParam(this, "learningDecay", "Learning rate, set as an" +
- " exponential decay rate. This should be between (0.5, 1.0] to guarantee asymptotic" +
- " convergence.", ParamValidators.gt(0))
+ final val learningDecay = new DoubleParam(this, "learningDecay", "(For online optimizer)" +
+ " Learning rate, set as an exponential decay rate. This should be between (0.5, 1.0] to" +
+ " guarantee asymptotic convergence.", ParamValidators.gt(0))
/** @group expertGetParam */
@Since("1.6.0")
def getLearningDecay: Double = $(learningDecay)
/**
+ * For Online optimizer only: [[optimizer]] = "online".
+ *
* Fraction of the corpus to be sampled and used in each iteration of mini-batch gradient descent,
* in range (0, 1].
*
@@ -230,11 +242,13 @@ private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasM
* [[org.apache.spark.mllib.clustering.OnlineLDAOptimizer]].
*
* Default: 0.05, i.e., 5% of total documents.
+ *
* @group param
*/
@Since("1.6.0")
- final val subsamplingRate = new DoubleParam(this, "subsamplingRate", "Fraction of the corpus" +
- " to be sampled and used in each iteration of mini-batch gradient descent, in range (0, 1].",
+ final val subsamplingRate = new DoubleParam(this, "subsamplingRate", "(For online optimizer)" +
+ " Fraction of the corpus to be sampled and used in each iteration of mini-batch" +
+ " gradient descent, in range (0, 1].",
ParamValidators.inRange(0.0, 1.0, lowerInclusive = false, upperInclusive = true))
/** @group getParam */
@@ -242,23 +256,52 @@ private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasM
def getSubsamplingRate: Double = $(subsamplingRate)
/**
+ * For Online optimizer only (currently): [[optimizer]] = "online".
+ *
* Indicates whether the docConcentration (Dirichlet parameter for
* document-topic distribution) will be optimized during training.
* Setting this to true will make the model more expressive and fit the training data better.
* Default: false
+ *
* @group expertParam
*/
@Since("1.6.0")
final val optimizeDocConcentration = new BooleanParam(this, "optimizeDocConcentration",
- "Indicates whether the docConcentration (Dirichlet parameter for document-topic" +
- " distribution) will be optimized during training.")
+ "(For online optimizer only, currently) Indicates whether the docConcentration" +
+ " (Dirichlet parameter for document-topic distribution) will be optimized during training.")
/** @group expertGetParam */
@Since("1.6.0")
def getOptimizeDocConcentration: Boolean = $(optimizeDocConcentration)
/**
+ * For EM optimizer only: [[optimizer]] = "em".
+ *
+ * If using checkpointing, this indicates whether to keep the last
+ * checkpoint. If false, then the checkpoint will be deleted. Deleting the checkpoint can
+ * cause failures if a data partition is lost, so set this bit with care.
+ * Note that checkpoints will be cleaned up via reference counting, regardless.
+ *
+ * See [[DistributedLDAModel.getCheckpointFiles]] for getting remaining checkpoints and
+ * [[DistributedLDAModel.deleteCheckpointFiles]] for removing remaining checkpoints.
+ *
+ * Default: true
+ *
+ * @group expertParam
+ */
+ @Since("2.0.0")
+ final val keepLastCheckpoint = new BooleanParam(this, "keepLastCheckpoint",
+ "(For EM optimizer) If using checkpointing, this indicates whether to keep the last" +
+ " checkpoint. If false, then the checkpoint will be deleted. Deleting the checkpoint can" +
+ " cause failures if a data partition is lost, so set this bit with care.")
+
+ /** @group expertGetParam */
+ @Since("2.0.0")
+ def getKeepLastCheckpoint: Boolean = $(keepLastCheckpoint)
+
+ /**
* Validates and transforms the input schema.
+ *
* @param schema input schema
* @return output schema
*/
@@ -303,6 +346,7 @@ private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasM
.setOptimizeDocConcentration($(optimizeDocConcentration))
case "em" =>
new OldEMLDAOptimizer()
+ .setKeepLastCheckpoint($(keepLastCheckpoint))
}
}
@@ -341,6 +385,7 @@ sealed abstract class LDAModel private[ml] (
/**
* The features for LDA should be a [[Vector]] representing the word counts in a document.
* The vector should be of length vocabSize, with counts for each term (word).
+ *
* @group setParam
*/
@Since("1.6.0")
@@ -357,15 +402,15 @@ sealed abstract class LDAModel private[ml] (
* is set to "em"), this involves collecting a large [[topicsMatrix]] to the driver.
* This implementation may be changed in the future.
*/
- @Since("1.6.0")
- override def transform(dataset: DataFrame): DataFrame = {
+ @Since("2.0.0")
+ override def transform(dataset: Dataset[_]): DataFrame = {
if ($(topicDistributionCol).nonEmpty) {
val t = udf(oldLocalModel.getTopicDistributionMethod(sqlContext.sparkContext))
- dataset.withColumn($(topicDistributionCol), t(col($(featuresCol))))
+ dataset.withColumn($(topicDistributionCol), t(col($(featuresCol)))).toDF
} else {
logWarning("LDAModel.transform was called without any output columns. Set an output column" +
" such as topicDistributionCol to produce results.")
- dataset
+ dataset.toDF
}
}
@@ -410,8 +455,8 @@ sealed abstract class LDAModel private[ml] (
* @param dataset test corpus to use for calculating log likelihood
* @return variational lower bound on the log likelihood of the entire corpus
*/
- @Since("1.6.0")
- def logLikelihood(dataset: DataFrame): Double = {
+ @Since("2.0.0")
+ def logLikelihood(dataset: Dataset[_]): Double = {
val oldDataset = LDA.getOldDataset(dataset, $(featuresCol))
oldLocalModel.logLikelihood(oldDataset)
}
@@ -427,8 +472,8 @@ sealed abstract class LDAModel private[ml] (
* @param dataset test corpus to use for calculating perplexity
* @return Variational upper bound on log perplexity per token.
*/
- @Since("1.6.0")
- def logPerplexity(dataset: DataFrame): Double = {
+ @Since("2.0.0")
+ def logPerplexity(dataset: Dataset[_]): Double = {
val oldDataset = LDA.getOldDataset(dataset, $(featuresCol))
oldLocalModel.logPerplexity(oldDataset)
}
@@ -619,6 +664,35 @@ class DistributedLDAModel private[ml] (
@Since("1.6.0")
lazy val logPrior: Double = oldDistributedModel.logPrior
+ private var _checkpointFiles: Array[String] = oldDistributedModel.checkpointFiles
+
+ /**
+ * If using checkpointing and [[LDA.keepLastCheckpoint]] is set to true, then there may be
+ * saved checkpoint files. This method is provided so that users can manage those files.
+ *
+ * Note that removing the checkpoints can cause failures if a partition is lost and is needed
+ * by certain [[DistributedLDAModel]] methods. Reference counting will clean up the checkpoints
+ * when this model and derivative data go out of scope.
+ *
+ * @return Checkpoint files from training
+ */
+ @DeveloperApi
+ @Since("2.0.0")
+ def getCheckpointFiles: Array[String] = _checkpointFiles
+
+ /**
+ * Remove any remaining checkpoint files from training.
+ *
+ * @see [[getCheckpointFiles]]
+ */
+ @DeveloperApi
+ @Since("2.0.0")
+ def deleteCheckpointFiles(): Unit = {
+ val fs = FileSystem.get(sqlContext.sparkContext.hadoopConfiguration)
+ _checkpointFiles.foreach(PeriodicCheckpointer.removeCheckpointFile(_, fs))
+ _checkpointFiles = Array.empty[String]
+ }
+
@Since("1.6.0")
override def write: MLWriter = new DistributedLDAModel.DistributedWriter(this)
}
@@ -696,11 +770,12 @@ class LDA @Since("1.6.0") (
setDefault(maxIter -> 20, k -> 10, optimizer -> "online", checkpointInterval -> 10,
learningOffset -> 1024, learningDecay -> 0.51, subsamplingRate -> 0.05,
- optimizeDocConcentration -> true)
+ optimizeDocConcentration -> true, keepLastCheckpoint -> true)
/**
* The features for LDA should be a [[Vector]] representing the word counts in a document.
* The vector should be of length vocabSize, with counts for each term (word).
+ *
* @group setParam
*/
@Since("1.6.0")
@@ -758,11 +833,15 @@ class LDA @Since("1.6.0") (
@Since("1.6.0")
def setOptimizeDocConcentration(value: Boolean): this.type = set(optimizeDocConcentration, value)
+ /** @group expertSetParam */
+ @Since("2.0.0")
+ def setKeepLastCheckpoint(value: Boolean): this.type = set(keepLastCheckpoint, value)
+
@Since("1.6.0")
override def copy(extra: ParamMap): LDA = defaultCopy(extra)
- @Since("1.6.0")
- override def fit(dataset: DataFrame): LDAModel = {
+ @Since("2.0.0")
+ override def fit(dataset: Dataset[_]): LDAModel = {
transformSchema(dataset.schema, logging = true)
val oldLDA = new OldLDA()
.setK($(k))
@@ -794,7 +873,7 @@ class LDA @Since("1.6.0") (
private[clustering] object LDA extends DefaultParamsReadable[LDA] {
/** Get dataset for spark.mllib LDA */
- def getOldDataset(dataset: DataFrame, featuresCol: String): RDD[(Long, Vector)] = {
+ def getOldDataset(dataset: Dataset[_], featuresCol: String): RDD[(Long, Vector)] = {
dataset
.withColumn("docId", monotonicallyIncreasingId())
.select("docId", featuresCol)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala
index 337ffbe90f..bde8c275fd 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala
@@ -23,7 +23,7 @@ import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable, SchemaUtils}
import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics
import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
-import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.{Dataset, Row}
import org.apache.spark.sql.types.DoubleType
/**
@@ -69,8 +69,8 @@ class BinaryClassificationEvaluator @Since("1.4.0") (@Since("1.4.0") override va
setDefault(metricName -> "areaUnderROC")
- @Since("1.2.0")
- override def evaluate(dataset: DataFrame): Double = {
+ @Since("2.0.0")
+ override def evaluate(dataset: Dataset[_]): Double = {
val schema = dataset.schema
SchemaUtils.checkColumnTypes(schema, $(rawPredictionCol), Seq(DoubleType, new VectorUDT))
SchemaUtils.checkColumnType(schema, $(labelCol), DoubleType)
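
The evaluator hunks only widen evaluate to Dataset[_] (with the @Since bumped to 2.0.0), so existing DataFrame callers compile unchanged. A hedged sketch, assuming `scored` is a binary classifier's transform() output with the default "rawPrediction" and "label" columns:

import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator

// Assumed: `scored` holds "rawPrediction" and "label" columns.
val auc = new BinaryClassificationEvaluator()
  .setMetricName("areaUnderROC")
  .evaluate(scored)
println(s"areaUnderROC = $auc")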
diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/Evaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/Evaluator.scala
index 0f22cca3a7..5f765c071b 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/Evaluator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/Evaluator.scala
@@ -19,7 +19,7 @@ package org.apache.spark.ml.evaluation
import org.apache.spark.annotation.{DeveloperApi, Since}
import org.apache.spark.ml.param.{ParamMap, Params}
-import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.Dataset
/**
* :: DeveloperApi ::
@@ -36,8 +36,8 @@ abstract class Evaluator extends Params {
* @param paramMap parameter map that specifies the input columns and output metrics
* @return metric
*/
- @Since("1.5.0")
- def evaluate(dataset: DataFrame, paramMap: ParamMap): Double = {
+ @Since("2.0.0")
+ def evaluate(dataset: Dataset[_], paramMap: ParamMap): Double = {
this.copy(paramMap).evaluate(dataset)
}
@@ -46,8 +46,8 @@ abstract class Evaluator extends Params {
* @param dataset a dataset that contains labels/observations and predictions.
* @return metric
*/
- @Since("1.5.0")
- def evaluate(dataset: DataFrame): Double
+ @Since("2.0.0")
+ def evaluate(dataset: Dataset[_]): Double
/**
* Indicates whether the metric returned by [[evaluate()]] should be maximized (true, default)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala
index 55ff44323a..3acfc221c9 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala
@@ -22,7 +22,7 @@ import org.apache.spark.ml.param.{Param, ParamMap, ParamValidators}
import org.apache.spark.ml.param.shared.{HasLabelCol, HasPredictionCol}
import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable, SchemaUtils}
import org.apache.spark.mllib.evaluation.MulticlassMetrics
-import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.{Dataset, Row}
import org.apache.spark.sql.types.DoubleType
/**
@@ -68,8 +68,8 @@ class MulticlassClassificationEvaluator @Since("1.5.0") (@Since("1.5.0") overrid
setDefault(metricName -> "f1")
- @Since("1.5.0")
- override def evaluate(dataset: DataFrame): Double = {
+ @Since("2.0.0")
+ override def evaluate(dataset: Dataset[_]): Double = {
val schema = dataset.schema
SchemaUtils.checkColumnType(schema, $(predictionCol), DoubleType)
SchemaUtils.checkColumnType(schema, $(labelCol), DoubleType)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala
index 9976d7ed43..ed04b67bcc 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala
@@ -22,7 +22,7 @@ import org.apache.spark.ml.param.{Param, ParamMap, ParamValidators}
import org.apache.spark.ml.param.shared.{HasLabelCol, HasPredictionCol}
import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable}
import org.apache.spark.mllib.evaluation.RegressionMetrics
-import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.{DataFrame, Dataset, Row}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{DoubleType, FloatType}
@@ -39,11 +39,12 @@ final class RegressionEvaluator @Since("1.4.0") (@Since("1.4.0") override val ui
def this() = this(Identifiable.randomUID("regEval"))
/**
- * param for metric name in evaluation (supports `"rmse"` (default), `"mse"`, `"r2"`, and `"mae"`)
+ * Param for metric name in evaluation. Supports:
+ * - `"rmse"` (default): root mean squared error
+ * - `"mse"`: mean squared error
+ * - `"r2"`: R^2^ metric
+ * - `"mae"`: mean absolute error
*
- * Because we will maximize evaluation value (ref: `CrossValidator`),
- * when we evaluate a metric that is needed to minimize (e.g., `"rmse"`, `"mse"`, `"mae"`),
- * we take and output the negative of this metric.
* @group param
*/
@Since("1.4.0")
@@ -70,8 +71,8 @@ final class RegressionEvaluator @Since("1.4.0") (@Since("1.4.0") override val ui
setDefault(metricName -> "rmse")
- @Since("1.4.0")
- override def evaluate(dataset: DataFrame): Double = {
+ @Since("2.0.0")
+ override def evaluate(dataset: Dataset[_]): Double = {
val schema = dataset.schema
val predictionColName = $(predictionCol)
val predictionType = schema($(predictionCol)).dataType
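
The rewritten RegressionEvaluator param doc enumerates the supported metric names. A sketch of switching between them, assuming `predictions` is a regression model's transform() output with "prediction" and "label" columns:

import org.apache.spark.ml.evaluation.RegressionEvaluator

// Assumed: `predictions` holds "prediction" and "label" columns.
val evaluator = new RegressionEvaluator()
val rmse = evaluator.setMetricName("rmse").evaluate(predictions)
val r2 = evaluator.setMetricName("r2").evaluate(predictions)
println(s"rmse = $rmse, r2 = $r2")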
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala
index 2f8e3a0371..898ac2cc89 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala
@@ -64,7 +64,8 @@ final class Binarizer(override val uid: String)
/** @group setParam */
def setOutputCol(value: String): this.type = set(outputCol, value)
- override def transform(dataset: DataFrame): DataFrame = {
+ @Since("2.0.0")
+ override def transform(dataset: Dataset[_]): DataFrame = {
val outputSchema = transformSchema(dataset.schema, logging = true)
val schema = dataset.schema
val inputType = schema($(inputCol)).dataType
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
index 33abc7c99d..10e622ace6 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
@@ -68,7 +68,8 @@ final class Bucketizer(override val uid: String)
/** @group setParam */
def setOutputCol(value: String): this.type = set(outputCol, value)
- override def transform(dataset: DataFrame): DataFrame = {
+ @Since("2.0.0")
+ override def transform(dataset: Dataset[_]): DataFrame = {
transformSchema(dataset.schema)
val bucketizer = udf { feature: Double =>
Bucketizer.binarySearchForBuckets($(splits), feature)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala
index b9e9d56853..cfecae7e0b 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala
@@ -77,7 +77,8 @@ final class ChiSqSelector(override val uid: String)
/** @group setParam */
def setLabelCol(value: String): this.type = set(labelCol, value)
- override def fit(dataset: DataFrame): ChiSqSelectorModel = {
+ @Since("2.0.0")
+ override def fit(dataset: Dataset[_]): ChiSqSelectorModel = {
transformSchema(dataset.schema, logging = true)
val input = dataset.select($(labelCol), $(featuresCol)).rdd.map {
case Row(label: Double, features: Vector) =>
@@ -127,7 +128,8 @@ final class ChiSqSelectorModel private[ml] (
/** @group setParam */
def setLabelCol(value: String): this.type = set(labelCol, value)
- override def transform(dataset: DataFrame): DataFrame = {
+ @Since("2.0.0")
+ override def transform(dataset: Dataset[_]): DataFrame = {
val transformedSchema = transformSchema(dataset.schema, logging = true)
val newField = transformedSchema.last
val selector = udf { chiSqSelector.transform _ }
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala
index 5694b3890f..922670a41b 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala
@@ -26,7 +26,7 @@ import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
import org.apache.spark.ml.util._
import org.apache.spark.mllib.linalg.{Vectors, VectorUDT}
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
import org.apache.spark.util.collection.OpenHashMap
@@ -100,6 +100,21 @@ private[feature] trait CountVectorizerParams extends Params with HasInputCol wit
/** @group getParam */
def getMinTF: Double = $(minTF)
+
+ /**
+ * Binary toggle to control the output vector values.
+ * If True, all nonzero counts (after minTF filter applied) are set to 1. This is useful for
+ * discrete probabilistic models that model binary events rather than integer counts.
+ * Default: false
+ * @group param
+ */
+ val binary: BooleanParam =
+ new BooleanParam(this, "binary", "If True, all non zero counts are set to 1.")
+
+ /** @group getParam */
+ def getBinary: Boolean = $(binary)
+
+ setDefault(binary -> false)
}
/**
@@ -127,9 +142,13 @@ class CountVectorizer(override val uid: String)
/** @group setParam */
def setMinTF(value: Double): this.type = set(minTF, value)
+ /** @group setParam */
+ def setBinary(value: Boolean): this.type = set(binary, value)
+
setDefault(vocabSize -> (1 << 18), minDF -> 1)
- override def fit(dataset: DataFrame): CountVectorizerModel = {
+ @Since("2.0.0")
+ override def fit(dataset: Dataset[_]): CountVectorizerModel = {
transformSchema(dataset.schema, logging = true)
val vocSize = $(vocabSize)
val input = dataset.select($(inputCol)).rdd.map(_.getAs[Seq[String]](0))
@@ -152,16 +171,10 @@ class CountVectorizer(override val uid: String)
(word, count)
}.cache()
val fullVocabSize = wordCounts.count()
- val vocab: Array[String] = {
- val tmpSortedWC: Array[(String, Long)] = if (fullVocabSize <= vocSize) {
- // Use all terms
- wordCounts.collect().sortBy(-_._2)
- } else {
- // Sort terms to select vocab
- wordCounts.sortBy(_._2, ascending = false).take(vocSize)
- }
- tmpSortedWC.map(_._1)
- }
+
+ val vocab = wordCounts
+ .top(math.min(fullVocabSize, vocSize).toInt)(Ordering.by(_._2))
+ .map(_._1)
require(vocab.length > 0, "The vocabulary size should be > 0. Lower minDF as necessary.")
copyValues(new CountVectorizerModel(uid, vocab).setParent(this))
@@ -206,30 +219,14 @@ class CountVectorizerModel(override val uid: String, val vocabulary: Array[Strin
/** @group setParam */
def setMinTF(value: Double): this.type = set(minTF, value)
- /**
- * Binary toggle to control the output vector values.
- * If True, all nonzero counts (after minTF filter applied) are set to 1. This is useful for
- * discrete probabilistic models that model binary events rather than integer counts.
- * Default: false
- * @group param
- */
- val binary: BooleanParam =
- new BooleanParam(this, "binary", "If True, all non zero counts are set to 1. " +
- "This is useful for discrete probabilistic models that model binary events rather " +
- "than integer counts")
-
- /** @group getParam */
- def getBinary: Boolean = $(binary)
-
/** @group setParam */
def setBinary(value: Boolean): this.type = set(binary, value)
- setDefault(binary -> false)
-
/** Dictionary created from [[vocabulary]] and its indices, broadcast once for [[transform()]] */
private var broadcastDict: Option[Broadcast[Map[String, Int]]] = None
- override def transform(dataset: DataFrame): DataFrame = {
+ @Since("2.0.0")
+ override def transform(dataset: Dataset[_]): DataFrame = {
transformSchema(dataset.schema, logging = true)
if (broadcastDict.isEmpty) {
val dict = vocabulary.zipWithIndex.toMap
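
Moving the binary param from CountVectorizerModel into CountVectorizerParams means the flag can now be set on the estimator before fitting and is inherited by the fitted model. A sketch with assumed toy data and sqlContext:

import org.apache.spark.ml.feature.CountVectorizer

val docs = sqlContext.createDataFrame(Seq(
  Tuple1(Seq("a", "b", "a", "c")),
  Tuple1(Seq("a", "b", "b"))
)).toDF("words")

// setBinary is now available on the estimator, not only on the model.
val cvModel = new CountVectorizer()
  .setInputCol("words")
  .setOutputCol("features")
  .setBinary(true)
  .fit(docs)

// With binary = true every non-zero term count is emitted as 1.0.
cvModel.transform(docs).select("features").show(truncate = false)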
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/ElementwiseProduct.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/ElementwiseProduct.scala
index 2c7ffdb7ba..1b0a9a12e8 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/ElementwiseProduct.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/ElementwiseProduct.scala
@@ -38,9 +38,9 @@ class ElementwiseProduct(override val uid: String)
def this() = this(Identifiable.randomUID("elemProd"))
/**
- * the vector to multiply with input vectors
- * @group param
- */
+ * the vector to multiply with input vectors
+ * @group param
+ */
val scalingVec: Param[Vector] = new Param(this, "scalingVec", "vector for hadamard product")
/** @group setParam */
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala
index 61a78d73c4..467ad73074 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala
@@ -20,11 +20,11 @@ package org.apache.spark.ml.feature
import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.ml.Transformer
import org.apache.spark.ml.attribute.AttributeGroup
-import org.apache.spark.ml.param.{IntParam, ParamMap, ParamValidators}
+import org.apache.spark.ml.param.{BooleanParam, IntParam, ParamMap, ParamValidators}
import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
import org.apache.spark.ml.util._
import org.apache.spark.mllib.feature
-import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.functions.{col, udf}
import org.apache.spark.sql.types.{ArrayType, StructType}
@@ -52,7 +52,18 @@ class HashingTF(override val uid: String)
val numFeatures = new IntParam(this, "numFeatures", "number of features (> 0)",
ParamValidators.gt(0))
- setDefault(numFeatures -> (1 << 18))
+ /**
+ * Binary toggle to control term frequency counts.
+ * If true, all non-zero counts are set to 1. This is useful for discrete probabilistic
+ * models that model binary events rather than integer counts.
+ * (default = false)
+ * @group param
+ */
+ val binary = new BooleanParam(this, "binary", "If true, all non zero counts are set to 1. " +
+ "This is useful for discrete probabilistic models that model binary events rather " +
+ "than integer counts")
+
+ setDefault(numFeatures -> (1 << 18), binary -> false)
/** @group getParam */
def getNumFeatures: Int = $(numFeatures)
@@ -60,9 +71,16 @@ class HashingTF(override val uid: String)
/** @group setParam */
def setNumFeatures(value: Int): this.type = set(numFeatures, value)
- override def transform(dataset: DataFrame): DataFrame = {
+ /** @group getParam */
+ def getBinary: Boolean = $(binary)
+
+ /** @group setParam */
+ def setBinary(value: Boolean): this.type = set(binary, value)
+
+ @Since("2.0.0")
+ override def transform(dataset: Dataset[_]): DataFrame = {
val outputSchema = transformSchema(dataset.schema)
- val hashingTF = new feature.HashingTF($(numFeatures))
+ val hashingTF = new feature.HashingTF($(numFeatures)).setBinary($(binary))
val t = udf { terms: Seq[_] => hashingTF.transform(terms) }
val metadata = outputSchema($(outputCol)).metadata
dataset.select(col("*"), t(col($(inputCol))).as($(outputCol), metadata))
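
HashingTF gains the same binary toggle and forwards it to the underlying mllib transformer. A sketch, again with toy data and sqlContext assumed:

import org.apache.spark.ml.feature.HashingTF

val docs = sqlContext.createDataFrame(Seq(
  Tuple1(Seq("spark", "spark", "hashing", "tf"))
)).toDF("words")

val tf = new HashingTF()
  .setInputCol("words")
  .setOutputCol("features")
  .setNumFeatures(1 << 10)
  .setBinary(true)  // new toggle: repeated terms contribute 1.0 instead of their count

tf.transform(docs).select("features").show(truncate = false)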
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala
index f36cf503a0..5075b78c98 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala
@@ -76,7 +76,8 @@ final class IDF(override val uid: String) extends Estimator[IDFModel] with IDFBa
/** @group setParam */
def setMinDocFreq(value: Int): this.type = set(minDocFreq, value)
- override def fit(dataset: DataFrame): IDFModel = {
+ @Since("2.0.0")
+ override def fit(dataset: Dataset[_]): IDFModel = {
transformSchema(dataset.schema, logging = true)
val input = dataset.select($(inputCol)).rdd.map { case Row(v: Vector) => v }
val idf = new feature.IDF($(minDocFreq)).fit(input)
@@ -115,7 +116,8 @@ class IDFModel private[ml] (
/** @group setParam */
def setOutputCol(value: String): this.type = set(outputCol, value)
- override def transform(dataset: DataFrame): DataFrame = {
+ @Since("2.0.0")
+ override def transform(dataset: Dataset[_]): DataFrame = {
transformSchema(dataset.schema, logging = true)
val idf = udf { vec: Vector => idfModel.transform(vec) }
dataset.withColumn($(outputCol), idf(col($(inputCol))))
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala
index d3fe6e528f..9ca34e9ae2 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala
@@ -27,7 +27,7 @@ import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.util._
import org.apache.spark.ml.Transformer
import org.apache.spark.mllib.linalg.{Vector, Vectors, VectorUDT}
-import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.{DataFrame, Dataset, Row}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
@@ -68,8 +68,8 @@ class Interaction @Since("1.6.0") (override val uid: String) extends Transformer
StructType(schema.fields :+ StructField($(outputCol), new VectorUDT, false))
}
- @Since("1.6.0")
- override def transform(dataset: DataFrame): DataFrame = {
+ @Since("2.0.0")
+ override def transform(dataset: Dataset[_]): DataFrame = {
val inputFeatures = $(inputCols).map(c => dataset.schema(c))
val featureEncoders = getFeatureEncoders(inputFeatures)
val featureAttrs = getFeatureAttrs(inputFeatures)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala
index 7de5a4d5d3..e9df600c8a 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala
@@ -66,7 +66,8 @@ class MaxAbsScaler @Since("2.0.0") (override val uid: String)
/** @group setParam */
def setOutputCol(value: String): this.type = set(outputCol, value)
- override def fit(dataset: DataFrame): MaxAbsScalerModel = {
+ @Since("2.0.0")
+ override def fit(dataset: Dataset[_]): MaxAbsScalerModel = {
transformSchema(dataset.schema, logging = true)
val input = dataset.select($(inputCol)).rdd.map { case Row(v: Vector) => v }
val summary = Statistics.colStats(input)
@@ -111,7 +112,8 @@ class MaxAbsScalerModel private[ml] (
/** @group setParam */
def setOutputCol(value: String): this.type = set(outputCol, value)
- override def transform(dataset: DataFrame): DataFrame = {
+ @Since("2.0.0")
+ override def transform(dataset: Dataset[_]): DataFrame = {
transformSchema(dataset.schema, logging = true)
// TODO: this looks hack, we may have to handle sparse and dense vectors separately.
val maxAbsUnzero = Vectors.dense(maxAbs.toArray.map(x => if (x == 0) 1 else x))
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala
index b13684a1cb..125becbb8a 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala
@@ -103,7 +103,8 @@ class MinMaxScaler(override val uid: String)
/** @group setParam */
def setMax(value: Double): this.type = set(max, value)
- override def fit(dataset: DataFrame): MinMaxScalerModel = {
+ @Since("2.0.0")
+ override def fit(dataset: Dataset[_]): MinMaxScalerModel = {
transformSchema(dataset.schema, logging = true)
val input = dataset.select($(inputCol)).rdd.map { case Row(v: Vector) => v }
val summary = Statistics.colStats(input)
@@ -154,7 +155,8 @@ class MinMaxScalerModel private[ml] (
/** @group setParam */
def setMax(value: Double): this.type = set(max, value)
- override def transform(dataset: DataFrame): DataFrame = {
+ @Since("2.0.0")
+ override def transform(dataset: Dataset[_]): DataFrame = {
val originalRange = (originalMax.toBreeze - originalMin.toBreeze).toArray
val minArray = originalMin.toArray
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala
index 4f67042629..99357793db 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala
@@ -24,7 +24,7 @@ import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
import org.apache.spark.ml.util._
import org.apache.spark.mllib.linalg.Vectors
-import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.functions.{col, udf}
import org.apache.spark.sql.types.{DoubleType, NumericType, StructType}
@@ -121,7 +121,8 @@ class OneHotEncoder(override val uid: String) extends Transformer
StructType(outputFields)
}
- override def transform(dataset: DataFrame): DataFrame = {
+ @Since("2.0.0")
+ override def transform(dataset: Dataset[_]): DataFrame = {
// schema transformation
val inputColName = $(inputCol)
val outputColName = $(outputCol)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala
index 305c3d187f..9cf722e121 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala
@@ -68,7 +68,8 @@ class PCA (override val uid: String) extends Estimator[PCAModel] with PCAParams
/**
* Computes a [[PCAModel]] that contains the principal components of the input vectors.
*/
- override def fit(dataset: DataFrame): PCAModel = {
+ @Since("2.0.0")
+ override def fit(dataset: Dataset[_]): PCAModel = {
transformSchema(dataset.schema, logging = true)
val input = dataset.select($(inputCol)).rdd.map { case Row(v: Vector) => v}
val pca = new feature.PCA(k = $(k))
@@ -124,7 +125,8 @@ class PCAModel private[ml] (
* NOTE: Vectors to be transformed must be the same length
* as the source vectors given to [[PCA.fit()]].
*/
- override def transform(dataset: DataFrame): DataFrame = {
+ @Since("2.0.0")
+ override def transform(dataset: Dataset[_]): DataFrame = {
transformSchema(dataset.schema, logging = true)
val pcaModel = new feature.PCAModel($(k), pc, explainedVariance)
val pcaOp = udf { pcaModel.transform _ }
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala
index e486e92c12..5c7993af64 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala
@@ -23,10 +23,10 @@ import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.internal.Logging
import org.apache.spark.ml._
import org.apache.spark.ml.attribute.NominalAttribute
-import org.apache.spark.ml.param.{IntParam, _}
+import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol, HasSeed}
import org.apache.spark.ml.util._
-import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.{Dataset, Row}
import org.apache.spark.sql.types.{DoubleType, StructType}
import org.apache.spark.util.random.XORShiftRandom
@@ -37,7 +37,7 @@ private[feature] trait QuantileDiscretizerBase extends Params
with HasInputCol with HasOutputCol with HasSeed {
/**
- * Maximum number of buckets (quantiles, or categories) into which data points are grouped. Must
+ * Number of buckets (quantiles, or categories) into which data points are grouped. Must
* be >= 2.
* default: 2
* @group param
@@ -49,6 +49,21 @@ private[feature] trait QuantileDiscretizerBase extends Params
/** @group getParam */
def getNumBuckets: Int = getOrDefault(numBuckets)
+
+ /**
+ * Relative error (see documentation for
+   * [[org.apache.spark.sql.DataFrameStatFunctions.approxQuantile approxQuantile]] for a description).
+ * Must be a number in [0, 1].
+ * default: 0.001
+ * @group param
+ */
+ val relativeError = new DoubleParam(this, "relativeError", "The relative target precision " +
+ "for approxQuantile",
+ ParamValidators.inRange(0.0, 1.0))
+ setDefault(relativeError -> 0.001)
+
+ /** @group getParam */
+ def getRelativeError: Double = getOrDefault(relativeError)
}
/**
@@ -56,8 +71,7 @@ private[feature] trait QuantileDiscretizerBase extends Params
* `QuantileDiscretizer` takes a column with continuous features and outputs a column with binned
* categorical features. The bin ranges are chosen by taking a sample of the data and dividing it
* into roughly equal parts. The lower and upper bin bounds will be -Infinity and +Infinity,
- * covering all real values. This attempts to find numBuckets partitions based on a sample of data,
- * but it may find fewer depending on the data sample values.
+ * covering all real values.
*/
@Experimental
final class QuantileDiscretizer(override val uid: String)
@@ -66,6 +80,9 @@ final class QuantileDiscretizer(override val uid: String)
def this() = this(Identifiable.randomUID("quantileDiscretizer"))
/** @group setParam */
+ def setRelativeError(value: Double): this.type = set(relativeError, value)
+
+ /** @group setParam */
def setNumBuckets(value: Int): this.type = set(numBuckets, value)
/** @group setParam */
@@ -87,12 +104,13 @@ final class QuantileDiscretizer(override val uid: String)
StructType(outputFields)
}
- override def fit(dataset: DataFrame): Bucketizer = {
- val samples = QuantileDiscretizer
- .getSampledInput(dataset.select($(inputCol)), $(numBuckets), $(seed))
- .map { case Row(feature: Double) => feature }
- val candidates = QuantileDiscretizer.findSplitCandidates(samples, $(numBuckets) - 1)
- val splits = QuantileDiscretizer.getSplits(candidates)
+ @Since("2.0.0")
+ override def fit(dataset: Dataset[_]): Bucketizer = {
+ val splits = dataset.stat.approxQuantile($(inputCol),
+ (0.0 to 1.0 by 1.0/$(numBuckets)).toArray, $(relativeError))
+ splits(0) = Double.NegativeInfinity
+ splits(splits.length - 1) = Double.PositiveInfinity
+
val bucketizer = new Bucketizer(uid).setSplits(splits)
copyValues(bucketizer.setParent(this))
}
@@ -103,90 +121,6 @@ final class QuantileDiscretizer(override val uid: String)
@Since("1.6.0")
object QuantileDiscretizer extends DefaultParamsReadable[QuantileDiscretizer] with Logging {
- /**
- * Minimum number of samples required for finding splits, regardless of number of bins. If
- * the dataset has fewer rows than this value, the entire dataset will be used.
- */
- private[spark] val minSamplesRequired: Int = 10000
-
- /**
- * Sampling from the given dataset to collect quantile statistics.
- */
- private[feature] def getSampledInput(dataset: DataFrame, numBins: Int, seed: Long): Array[Row] = {
- val totalSamples = dataset.count()
- require(totalSamples > 0,
- "QuantileDiscretizer requires non-empty input dataset but was given an empty input.")
- val requiredSamples = math.max(numBins * numBins, minSamplesRequired)
- val fraction = math.min(requiredSamples.toDouble / totalSamples, 1.0)
- dataset.sample(withReplacement = false, fraction, new XORShiftRandom(seed).nextInt()).collect()
- }
-
- /**
- * Compute split points with respect to the sample distribution.
- */
- private[feature]
- def findSplitCandidates(samples: Array[Double], numSplits: Int): Array[Double] = {
- val valueCountMap = samples.foldLeft(Map.empty[Double, Int]) { (m, x) =>
- m + ((x, m.getOrElse(x, 0) + 1))
- }
- val valueCounts = valueCountMap.toSeq.sortBy(_._1).toArray ++ Array((Double.MaxValue, 1))
- val possibleSplits = valueCounts.length - 1
- if (possibleSplits <= numSplits) {
- valueCounts.dropRight(1).map(_._1)
- } else {
- val stride: Double = math.ceil(samples.length.toDouble / (numSplits + 1))
- val splitsBuilder = mutable.ArrayBuilder.make[Double]
- var index = 1
- // currentCount: sum of counts of values that have been visited
- var currentCount = valueCounts(0)._2
- // targetCount: target value for `currentCount`. If `currentCount` is closest value to
- // `targetCount`, then current value is a split threshold. After finding a split threshold,
- // `targetCount` is added by stride.
- var targetCount = stride
- while (index < valueCounts.length) {
- val previousCount = currentCount
- currentCount += valueCounts(index)._2
- val previousGap = math.abs(previousCount - targetCount)
- val currentGap = math.abs(currentCount - targetCount)
- // If adding count of current value to currentCount makes the gap between currentCount and
- // targetCount smaller, previous value is a split threshold.
- if (previousGap < currentGap) {
- splitsBuilder += valueCounts(index - 1)._1
- targetCount += stride
- }
- index += 1
- }
- splitsBuilder.result()
- }
- }
-
- /**
- * Adjust split candidates to proper splits by: adding positive/negative infinity to both sides as
- * needed, and adding a default split value of 0 if no good candidates are found.
- */
- private[feature] def getSplits(candidates: Array[Double]): Array[Double] = {
- val effectiveValues = if (candidates.nonEmpty) {
- if (candidates.head == Double.NegativeInfinity
- && candidates.last == Double.PositiveInfinity) {
- candidates.drop(1).dropRight(1)
- } else if (candidates.head == Double.NegativeInfinity) {
- candidates.drop(1)
- } else if (candidates.last == Double.PositiveInfinity) {
- candidates.dropRight(1)
- } else {
- candidates
- }
- } else {
- candidates
- }
-
- if (effectiveValues.isEmpty) {
- Array(Double.NegativeInfinity, 0, Double.PositiveInfinity)
- } else {
- Array(Double.NegativeInfinity) ++ effectiveValues ++ Array(Double.PositiveInfinity)
- }
- }
-
@Since("1.6.0")
override def load(path: String): QuantileDiscretizer = super.load(path)
}
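For reference, a minimal usage sketch of the reworked discretizer (not part of the patch). It assumes an existing DataFrame `df` with a numeric column "hour", and shows the new relativeError knob that fit() now passes through to DataFrame.stat.approxQuantile before pinning the outer split bounds to +/-Infinity:

    import org.apache.spark.ml.feature.QuantileDiscretizer

    // df is an assumed, pre-existing DataFrame with a Double column "hour".
    val discretizer = new QuantileDiscretizer()
      .setInputCol("hour")
      .setOutputCol("hourBucket")
      .setNumBuckets(3)
      .setRelativeError(0.001)  // new param added by this patch; 0.001 is its default

    // fit() computes approximate quantiles instead of sampling the data,
    // then replaces the first/last split with -Infinity / +Infinity.
    val bucketed = discretizer.fit(df).transform(df)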
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
index 12a76dbbfb..3ac6c77669 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
@@ -29,7 +29,7 @@ import org.apache.spark.ml.param.{Param, ParamMap}
import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasLabelCol}
import org.apache.spark.ml.util._
import org.apache.spark.mllib.linalg.VectorUDT
-import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.types._
/**
@@ -103,7 +103,8 @@ class RFormula(override val uid: String)
RFormulaParser.parse($(formula)).hasIntercept
}
- override def fit(dataset: DataFrame): RFormulaModel = {
+ @Since("2.0.0")
+ override def fit(dataset: Dataset[_]): RFormulaModel = {
require(isDefined(formula), "Formula must be defined first.")
val parsedFormula = RFormulaParser.parse($(formula))
val resolvedFormula = parsedFormula.resolve(dataset.schema)
@@ -204,7 +205,8 @@ class RFormulaModel private[feature](
private[ml] val pipelineModel: PipelineModel)
extends Model[RFormulaModel] with RFormulaBase with MLWritable {
- override def transform(dataset: DataFrame): DataFrame = {
+ @Since("2.0.0")
+ override def transform(dataset: Dataset[_]): DataFrame = {
checkCanTransform(dataset.schema)
transformLabel(pipelineModel.transform(dataset))
}
@@ -232,10 +234,10 @@ class RFormulaModel private[feature](
override def toString: String = s"RFormulaModel($resolvedFormula) (uid=$uid)"
- private def transformLabel(dataset: DataFrame): DataFrame = {
+ private def transformLabel(dataset: Dataset[_]): DataFrame = {
val labelName = resolvedFormula.label
if (hasLabelCol(dataset.schema)) {
- dataset
+ dataset.toDF
} else if (dataset.schema.exists(_.name == labelName)) {
dataset.schema(labelName).dataType match {
case _: NumericType | BooleanType =>
@@ -246,7 +248,7 @@ class RFormulaModel private[feature](
} else {
// Ignore the label field. This is a hack so that this transformer can also work on test
// datasets in a Pipeline.
- dataset
+ dataset.toDF
}
}
@@ -323,7 +325,7 @@ private class ColumnPruner(override val uid: String, val columnsToPrune: Set[Str
def this(columnsToPrune: Set[String]) =
this(Identifiable.randomUID("columnPruner"), columnsToPrune)
- override def transform(dataset: DataFrame): DataFrame = {
+ override def transform(dataset: Dataset[_]): DataFrame = {
val columnsToKeep = dataset.columns.filter(!columnsToPrune.contains(_))
dataset.select(columnsToKeep.map(dataset.col): _*)
}
@@ -396,7 +398,7 @@ private class VectorAttributeRewriter(
def this(vectorCol: String, prefixesToRewrite: Map[String, String]) =
this(Identifiable.randomUID("vectorAttrRewriter"), vectorCol, prefixesToRewrite)
- override def transform(dataset: DataFrame): DataFrame = {
+ override def transform(dataset: Dataset[_]): DataFrame = {
val metadata = {
val group = AttributeGroup.fromStructField(dataset.schema(vectorCol))
val attrs = group.attributes.get.map { attr =>
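The remaining feature transformers in this patch repeat the same mechanical change, so a compact sketch of the pattern may help (illustrative only; `PassThrough` is a made-up transformer, not part of Spark): transform and fit now take Dataset[_], and any code path that hands the input back unchanged converts it with toDF, as RFormulaModel.transformLabel does above.

    import org.apache.spark.ml.Transformer
    import org.apache.spark.ml.param.ParamMap
    import org.apache.spark.ml.util.Identifiable
    import org.apache.spark.sql.{DataFrame, Dataset}
    import org.apache.spark.sql.types.StructType

    class PassThrough(override val uid: String) extends Transformer {
      def this() = this(Identifiable.randomUID("passThrough"))

      // Dataset[_] in, DataFrame out; toDF mirrors what the patched transformers do
      // when they return the input dataset unchanged.
      override def transform(dataset: Dataset[_]): DataFrame = dataset.toDF

      override def transformSchema(schema: StructType): StructType = schema

      override def copy(extra: ParamMap): PassThrough = defaultCopy(extra)
    }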
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala
index e0ca45b9a6..2002d15745 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala
@@ -22,7 +22,7 @@ import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.ml.param.{Param, ParamMap}
import org.apache.spark.ml.Transformer
import org.apache.spark.ml.util._
-import org.apache.spark.sql.{DataFrame, Row, SQLContext}
+import org.apache.spark.sql.{DataFrame, Dataset, Row, SQLContext}
import org.apache.spark.sql.types.StructType
/**
@@ -63,13 +63,12 @@ class SQLTransformer @Since("1.6.0") (override val uid: String) extends Transfor
private val tableIdentifier: String = "__THIS__"
- @Since("1.6.0")
- override def transform(dataset: DataFrame): DataFrame = {
+ @Since("2.0.0")
+ override def transform(dataset: Dataset[_]): DataFrame = {
val tableName = Identifiable.randomUID(uid)
dataset.registerTempTable(tableName)
val realStatement = $(statement).replace(tableIdentifier, tableName)
- val outputDF = dataset.sqlContext.sql(realStatement)
- outputDF
+ dataset.sqlContext.sql(realStatement)
}
@Since("1.6.0")
@@ -78,8 +77,11 @@ class SQLTransformer @Since("1.6.0") (override val uid: String) extends Transfor
val sqlContext = SQLContext.getOrCreate(sc)
val dummyRDD = sc.parallelize(Seq(Row.empty))
val dummyDF = sqlContext.createDataFrame(dummyRDD, schema)
- dummyDF.registerTempTable(tableIdentifier)
- val outputSchema = sqlContext.sql($(statement)).schema
+ val tableName = Identifiable.randomUID(uid)
+ val realStatement = $(statement).replace(tableIdentifier, tableName)
+ dummyDF.registerTempTable(tableName)
+ val outputSchema = sqlContext.sql(realStatement).schema
+ sqlContext.dropTempTable(tableName)
outputSchema
}
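A short usage sketch of SQLTransformer (illustrative, assuming an existing DataFrame `df` with columns "v1" and "v2"), showing the __THIS__ placeholder that transform, and after this change transformSchema as well, substitutes with a uniquely named temp table:

    import org.apache.spark.ml.feature.SQLTransformer

    val sqlTrans = new SQLTransformer()
      .setStatement("SELECT *, (v1 + v2) AS v3 FROM __THIS__")

    // transform registers the input under a random table name, rewrites __THIS__
    // to that name, and runs the statement; transformSchema now does the same
    // against a dummy one-row DataFrame and drops the temp table afterwards.
    val result = sqlTrans.transform(df)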
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
index 26ee8e1bf1..118a6e3e6a 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
@@ -85,7 +85,8 @@ class StandardScaler(override val uid: String) extends Estimator[StandardScalerM
/** @group setParam */
def setWithStd(value: Boolean): this.type = set(withStd, value)
- override def fit(dataset: DataFrame): StandardScalerModel = {
+ @Since("2.0.0")
+ override def fit(dataset: Dataset[_]): StandardScalerModel = {
transformSchema(dataset.schema, logging = true)
val input = dataset.select($(inputCol)).rdd.map { case Row(v: Vector) => v }
val scaler = new feature.StandardScaler(withMean = $(withMean), withStd = $(withStd))
@@ -135,7 +136,8 @@ class StandardScalerModel private[ml] (
/** @group setParam */
def setOutputCol(value: String): this.type = set(outputCol, value)
- override def transform(dataset: DataFrame): DataFrame = {
+ @Since("2.0.0")
+ override def transform(dataset: Dataset[_]): DataFrame = {
transformSchema(dataset.schema, logging = true)
val scaler = new feature.StandardScalerModel(std, mean, $(withStd), $(withMean))
val scale = udf { scaler.transform _ }
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala
index 0a0e0b0960..b96bc48566 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala
@@ -22,7 +22,7 @@ import org.apache.spark.ml.Transformer
import org.apache.spark.ml.param.{BooleanParam, ParamMap, StringArrayParam}
import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
import org.apache.spark.ml.util._
-import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.functions.{col, udf}
import org.apache.spark.sql.types.{ArrayType, StringType, StructType}
@@ -125,7 +125,8 @@ class StopWordsRemover(override val uid: String)
setDefault(stopWords -> StopWords.English, caseSensitive -> false)
- override def transform(dataset: DataFrame): DataFrame = {
+ @Since("2.0.0")
+ override def transform(dataset: Dataset[_]): DataFrame = {
val outputSchema = transformSchema(dataset.schema)
val t = if ($(caseSensitive)) {
val stopWordsSet = $(stopWords).toSet
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
index faa0f6f407..7e0d374f02 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
@@ -26,7 +26,7 @@ import org.apache.spark.ml.attribute.{Attribute, NominalAttribute}
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.util._
-import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
import org.apache.spark.util.collection.OpenHashMap
@@ -80,7 +80,8 @@ class StringIndexer(override val uid: String) extends Estimator[StringIndexerMod
def setOutputCol(value: String): this.type = set(outputCol, value)
- override def fit(dataset: DataFrame): StringIndexerModel = {
+ @Since("2.0.0")
+ override def fit(dataset: Dataset[_]): StringIndexerModel = {
val counts = dataset.select(col($(inputCol)).cast(StringType))
.rdd
.map(_.getString(0))
@@ -144,11 +145,12 @@ class StringIndexerModel (
/** @group setParam */
def setOutputCol(value: String): this.type = set(outputCol, value)
- override def transform(dataset: DataFrame): DataFrame = {
+ @Since("2.0.0")
+ override def transform(dataset: Dataset[_]): DataFrame = {
if (!dataset.schema.fieldNames.contains($(inputCol))) {
logInfo(s"Input column ${$(inputCol)} does not exist during transformation. " +
"Skip StringIndexerModel.")
- return dataset
+ return dataset.toDF
}
validateAndTransformSchema(dataset.schema)
@@ -286,7 +288,8 @@ class IndexToString private[ml] (override val uid: String)
StructType(outputFields)
}
- override def transform(dataset: DataFrame): DataFrame = {
+ @Since("2.0.0")
+ override def transform(dataset: Dataset[_]): DataFrame = {
val inputColSchema = dataset.schema($(inputCol))
// If the labels array is empty use column metadata
val values = if ($(labels).isEmpty) {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala
index 957e8e7a59..4d3e46e488 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala
@@ -27,7 +27,7 @@ import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.util._
import org.apache.spark.mllib.linalg.{Vector, Vectors, VectorUDT}
-import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.{DataFrame, Dataset, Row}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
@@ -47,10 +47,11 @@ class VectorAssembler(override val uid: String)
/** @group setParam */
def setOutputCol(value: String): this.type = set(outputCol, value)
- override def transform(dataset: DataFrame): DataFrame = {
+ @Since("2.0.0")
+ override def transform(dataset: Dataset[_]): DataFrame = {
// Schema transformation.
val schema = dataset.schema
- lazy val first = dataset.first()
+ lazy val first = dataset.toDF.first()
val attrs = $(inputCols).flatMap { c =>
val field = schema(c)
val index = schema.fieldIndex(c)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
index bf4aef2a74..68b699d569 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
@@ -31,7 +31,7 @@ import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.util._
import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, VectorUDT}
-import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.{DataFrame, Dataset, Row}
import org.apache.spark.sql.functions.udf
import org.apache.spark.sql.types.{StructField, StructType}
import org.apache.spark.util.collection.OpenHashSet
@@ -108,7 +108,8 @@ class VectorIndexer(override val uid: String) extends Estimator[VectorIndexerMod
/** @group setParam */
def setOutputCol(value: String): this.type = set(outputCol, value)
- override def fit(dataset: DataFrame): VectorIndexerModel = {
+ @Since("2.0.0")
+ override def fit(dataset: Dataset[_]): VectorIndexerModel = {
transformSchema(dataset.schema, logging = true)
val firstRow = dataset.select($(inputCol)).take(1)
require(firstRow.length == 1, s"VectorIndexer cannot be fit on an empty dataset.")
@@ -345,7 +346,8 @@ class VectorIndexerModel private[ml] (
/** @group setParam */
def setOutputCol(value: String): this.type = set(outputCol, value)
- override def transform(dataset: DataFrame): DataFrame = {
+ @Since("2.0.0")
+ override def transform(dataset: Dataset[_]): DataFrame = {
transformSchema(dataset.schema, logging = true)
val newField = prepOutputField(dataset.schema)
val transformUDF = udf { (vector: Vector) => transformFunc(vector) }
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala
index b60e82de00..7a9468b87b 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala
@@ -24,7 +24,7 @@ import org.apache.spark.ml.param.{IntArrayParam, ParamMap, StringArrayParam}
import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
import org.apache.spark.ml.util._
import org.apache.spark.mllib.linalg._
-import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.StructType
@@ -89,7 +89,8 @@ final class VectorSlicer(override val uid: String)
/** @group setParam */
def setOutputCol(value: String): this.type = set(outputCol, value)
- override def transform(dataset: DataFrame): DataFrame = {
+ @Since("2.0.0")
+ override def transform(dataset: Dataset[_]): DataFrame = {
// Validity checks
transformSchema(dataset.schema)
val inputAttr = AttributeGroup.fromStructField(dataset.schema($(inputCol)))
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
index 95bae1c8a3..a72692960f 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
@@ -27,7 +27,7 @@ import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.util._
import org.apache.spark.mllib.feature
import org.apache.spark.mllib.linalg.{BLAS, Vector, Vectors, VectorUDT}
-import org.apache.spark.sql.{DataFrame, SQLContext}
+import org.apache.spark.sql.{DataFrame, Dataset, SQLContext}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
@@ -135,7 +135,8 @@ final class Word2Vec(override val uid: String) extends Estimator[Word2VecModel]
/** @group setParam */
def setMinCount(value: Int): this.type = set(minCount, value)
- override def fit(dataset: DataFrame): Word2VecModel = {
+ @Since("2.0.0")
+ override def fit(dataset: Dataset[_]): Word2VecModel = {
transformSchema(dataset.schema, logging = true)
val input = dataset.select($(inputCol)).rdd.map(_.getAs[Seq[String]](0))
val wordVectors = new feature.Word2Vec()
@@ -219,7 +220,8 @@ class Word2VecModel private[ml] (
* Transform a sentence column to a vector column to represent the whole sentence. The transform
* is performed by averaging all word vectors it contains.
*/
- override def transform(dataset: DataFrame): DataFrame = {
+ @Since("2.0.0")
+ override def transform(dataset: Dataset[_]): DataFrame = {
transformSchema(dataset.schema, logging = true)
val vectors = wordVectors.getVectors
.mapValues(vv => Vectors.dense(vv.map(_.toDouble)))
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
index d7837b6730..c368aadd23 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
@@ -18,6 +18,7 @@
package org.apache.spark.ml.param
import java.lang.reflect.Modifier
+import java.util.{List => JList}
import java.util.NoSuchElementException
import scala.annotation.varargs
@@ -833,6 +834,11 @@ final class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any])
this
}
+ /** Put param pairs with a [[java.util.List]] of values for Python. */
+ private[ml] def put(paramPairs: JList[ParamPair[_]]): this.type = {
+ put(paramPairs.asScala: _*)
+ }
+
/**
* Optionally returns the value associated with a param.
*/
@@ -932,6 +938,11 @@ final class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any])
}
}
+ /** Java-friendly method for Python API */
+ private[ml] def toList: java.util.List[ParamPair[_]] = {
+ this.toSeq.asJava
+ }
+
/**
* Number of param pairs in this map.
*/
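The two helpers added above are private[ml] plumbing for the Python API; on the public side they correspond to the ordinary varargs put and toSeq, roughly as in this sketch (LogisticRegression is just a convenient stand-in):

    import org.apache.spark.ml.classification.LogisticRegression
    import org.apache.spark.ml.param.ParamMap

    val lr = new LogisticRegression()
    val paramMap = ParamMap(lr.maxIter -> 20, lr.regParam -> 0.01)
    paramMap.put(lr.elasticNetParam -> 0.5)  // varargs put; the new JList overload wraps this
    val pairs = paramMap.toSeq               // the private[ml] toList is just toSeq.asJava for Py4J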
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
index 3ce129b12c..1d03a5b4f4 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
@@ -62,7 +62,7 @@ private[shared] object SharedParamsCodeGen {
"every 10 iterations", isValid = "(interval: Int) => interval == -1 || interval >= 1"),
ParamDesc[Boolean]("fitIntercept", "whether to fit an intercept term", Some("true")),
ParamDesc[String]("handleInvalid", "how to handle invalid entries. Options are skip (which " +
- "will filter out rows with bad values), or error (which will throw an errror). More " +
+ "will filter out rows with bad values), or error (which will throw an error). More " +
"options may be added later",
isValid = "ParamValidators.inArray(Array(\"skip\", \"error\"))"),
ParamDesc[Boolean]("standardization", "whether to standardize the training features" +
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
index 96263c5baf..64d6af2766 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
@@ -270,10 +270,10 @@ private[ml] trait HasFitIntercept extends Params {
private[ml] trait HasHandleInvalid extends Params {
/**
- * Param for how to handle invalid entries. Options are skip (which will filter out rows with bad values), or error (which will throw an errror). More options may be added later.
+ * Param for how to handle invalid entries. Options are skip (which will filter out rows with bad values), or error (which will throw an error). More options may be added later.
* @group param
*/
- final val handleInvalid: Param[String] = new Param[String](this, "handleInvalid", "how to handle invalid entries. Options are skip (which will filter out rows with bad values), or error (which will throw an errror). More options may be added later", ParamValidators.inArray(Array("skip", "error")))
+ final val handleInvalid: Param[String] = new Param[String](this, "handleInvalid", "how to handle invalid entries. Options are skip (which will filter out rows with bad values), or error (which will throw an error). More options may be added later", ParamValidators.inArray(Array("skip", "error")))
/** @group getParam */
final def getHandleInvalid: String = $(handleInvalid)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/AFTSurvivalRegressionWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/AFTSurvivalRegressionWrapper.scala
index 40590e71c4..7835468626 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/AFTSurvivalRegressionWrapper.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/AFTSurvivalRegressionWrapper.scala
@@ -22,7 +22,7 @@ import org.apache.spark.ml.{Pipeline, PipelineModel}
import org.apache.spark.ml.attribute.AttributeGroup
import org.apache.spark.ml.feature.RFormula
import org.apache.spark.ml.regression.{AFTSurvivalRegression, AFTSurvivalRegressionModel}
-import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.{DataFrame, Dataset}
private[r] class AFTSurvivalRegressionWrapper private (
pipeline: PipelineModel,
@@ -43,8 +43,8 @@ private[r] class AFTSurvivalRegressionWrapper private (
features ++ Array("Log(scale)")
}
- def transform(dataset: DataFrame): DataFrame = {
- pipeline.transform(dataset)
+ def transform(dataset: Dataset[_]): DataFrame = {
+ pipeline.transform(dataset).drop(aftModel.getFeaturesCol)
}
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala
new file mode 100644
index 0000000000..475a308385
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala
@@ -0,0 +1,79 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.r
+
+import org.apache.spark.ml.{Pipeline, PipelineModel}
+import org.apache.spark.ml.attribute.AttributeGroup
+import org.apache.spark.ml.feature.RFormula
+import org.apache.spark.ml.regression._
+import org.apache.spark.sql._
+
+private[r] class GeneralizedLinearRegressionWrapper private (
+ pipeline: PipelineModel,
+ val features: Array[String]) {
+
+ private val glm: GeneralizedLinearRegressionModel =
+ pipeline.stages(1).asInstanceOf[GeneralizedLinearRegressionModel]
+
+ lazy val rCoefficients: Array[Double] = if (glm.getFitIntercept) {
+ Array(glm.intercept) ++ glm.coefficients.toArray
+ } else {
+ glm.coefficients.toArray
+ }
+
+ lazy val rFeatures: Array[String] = if (glm.getFitIntercept) {
+ Array("(Intercept)") ++ features
+ } else {
+ features
+ }
+
+ def transform(dataset: DataFrame): DataFrame = {
+ pipeline.transform(dataset).drop(glm.getFeaturesCol)
+ }
+}
+
+private[r] object GeneralizedLinearRegressionWrapper {
+
+ def fit(
+ formula: String,
+ data: DataFrame,
+ family: String,
+ link: String,
+ epsilon: Double,
+ maxit: Int): GeneralizedLinearRegressionWrapper = {
+ val rFormula = new RFormula()
+ .setFormula(formula)
+ val rFormulaModel = rFormula.fit(data)
+ // get labels and feature names from output schema
+ val schema = rFormulaModel.transform(data).schema
+ val featureAttrs = AttributeGroup.fromStructField(schema(rFormula.getFeaturesCol))
+ .attributes.get
+ val features = featureAttrs.map(_.name.get)
+ // assemble and fit the pipeline
+ val glm = new GeneralizedLinearRegression()
+ .setFamily(family)
+ .setLink(link)
+ .setFitIntercept(rFormula.hasIntercept)
+ .setTol(epsilon)
+ .setMaxIter(maxit)
+ val pipeline = new Pipeline()
+ .setStages(Array(rFormulaModel, glm))
+ .fit(data)
+ new GeneralizedLinearRegressionWrapper(pipeline, features)
+ }
+}
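Outside the R wrapper, essentially the same pipeline can be assembled directly. A rough sketch, in which the DataFrame `data` and the column names are assumed:

    import org.apache.spark.ml.Pipeline
    import org.apache.spark.ml.feature.RFormula
    import org.apache.spark.ml.regression.GeneralizedLinearRegression

    // RFormula produces the "features"/"label" columns, mirroring what the wrapper does.
    val rFormula = new RFormula().setFormula("y ~ x1 + x2")
    val glm = new GeneralizedLinearRegression()
      .setFamily("gaussian")
      .setLink("identity")
      .setMaxIter(25)
      .setTol(1e-6)

    val model = new Pipeline().setStages(Array(rFormula, glm)).fit(data)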
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/KMeansWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/KMeansWrapper.scala
new file mode 100644
index 0000000000..9e2b81ee20
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/KMeansWrapper.scala
@@ -0,0 +1,85 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.r
+
+import org.apache.spark.ml.{Pipeline, PipelineModel}
+import org.apache.spark.ml.attribute.AttributeGroup
+import org.apache.spark.ml.clustering.{KMeans, KMeansModel}
+import org.apache.spark.ml.feature.VectorAssembler
+import org.apache.spark.sql.{DataFrame, Dataset}
+
+private[r] class KMeansWrapper private (
+ pipeline: PipelineModel) {
+
+ private val kMeansModel: KMeansModel = pipeline.stages(1).asInstanceOf[KMeansModel]
+
+ lazy val coefficients: Array[Double] = kMeansModel.clusterCenters.flatMap(_.toArray)
+
+ private lazy val attrs = AttributeGroup.fromStructField(
+ kMeansModel.summary.predictions.schema(kMeansModel.getFeaturesCol))
+
+ lazy val features: Array[String] = attrs.attributes.get.map(_.name.get)
+
+ lazy val k: Int = kMeansModel.getK
+
+ lazy val size: Array[Long] = kMeansModel.summary.clusterSizes
+
+ lazy val cluster: DataFrame = kMeansModel.summary.cluster
+
+ def fitted(method: String): DataFrame = {
+ if (method == "centers") {
+ kMeansModel.summary.predictions.drop(kMeansModel.getFeaturesCol)
+ } else if (method == "classes") {
+ kMeansModel.summary.cluster
+ } else {
+ throw new UnsupportedOperationException(
+ s"Method (centers or classes) required but $method found.")
+ }
+ }
+
+ def transform(dataset: Dataset[_]): DataFrame = {
+ pipeline.transform(dataset).drop(kMeansModel.getFeaturesCol)
+ }
+
+}
+
+private[r] object KMeansWrapper {
+
+ def fit(
+ data: DataFrame,
+ k: Double,
+ maxIter: Double,
+ initMode: String,
+ columns: Array[String]): KMeansWrapper = {
+
+ val assembler = new VectorAssembler()
+ .setInputCols(columns)
+ .setOutputCol("features")
+
+ val kMeans = new KMeans()
+ .setK(k.toInt)
+ .setMaxIter(maxIter.toInt)
+ .setInitMode(initMode)
+
+ val pipeline = new Pipeline()
+ .setStages(Array(assembler, kMeans))
+ .fit(data)
+
+ new KMeansWrapper(pipeline)
+ }
+}
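The equivalent pipeline outside the wrapper, as a sketch (`data` and the input column names are assumed):

    import org.apache.spark.ml.Pipeline
    import org.apache.spark.ml.clustering.KMeans
    import org.apache.spark.ml.feature.VectorAssembler

    val assembler = new VectorAssembler()
      .setInputCols(Array("x1", "x2"))
      .setOutputCol("features")

    val kMeans = new KMeans()
      .setK(3)
      .setMaxIter(20)
      .setInitMode("k-means||")

    val kmPipeline = new Pipeline().setStages(Array(assembler, kMeans)).fit(data)
    // The wrapper then drops the assembled "features" column before handing
    // predictions back to R, which is what its transform() above does.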
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala
index 07383d393d..b17207e99b 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala
@@ -21,7 +21,7 @@ import org.apache.spark.ml.{Pipeline, PipelineModel}
import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NominalAttribute}
import org.apache.spark.ml.classification.{NaiveBayes, NaiveBayesModel}
import org.apache.spark.ml.feature.{IndexToString, RFormula}
-import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.{DataFrame, Dataset}
private[r] class NaiveBayesWrapper private (
pipeline: PipelineModel,
@@ -36,8 +36,10 @@ private[r] class NaiveBayesWrapper private (
lazy val tables: Array[Double] = naiveBayesModel.theta.toArray.map(math.exp)
- def transform(dataset: DataFrame): DataFrame = {
- pipeline.transform(dataset).drop(PREDICTED_LABEL_INDEX_COL)
+ def transform(dataset: Dataset[_]): DataFrame = {
+ pipeline.transform(dataset)
+ .drop(PREDICTED_LABEL_INDEX_COL)
+ .drop(naiveBayesModel.getFeaturesCol)
}
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala b/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala
deleted file mode 100644
index d23e4fc9d1..0000000000
--- a/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala
+++ /dev/null
@@ -1,167 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.ml.api.r
-
-import org.apache.spark.ml.{Pipeline, PipelineModel}
-import org.apache.spark.ml.attribute._
-import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel}
-import org.apache.spark.ml.clustering.{KMeans, KMeansModel}
-import org.apache.spark.ml.feature.{RFormula, VectorAssembler}
-import org.apache.spark.ml.regression.{LinearRegression, LinearRegressionModel}
-import org.apache.spark.sql.DataFrame
-
-private[r] object SparkRWrappers {
- def fitRModelFormula(
- value: String,
- df: DataFrame,
- family: String,
- lambda: Double,
- alpha: Double,
- standardize: Boolean,
- solver: String): PipelineModel = {
- val formula = new RFormula().setFormula(value)
- val estimator = family match {
- case "gaussian" => new LinearRegression()
- .setRegParam(lambda)
- .setElasticNetParam(alpha)
- .setFitIntercept(formula.hasIntercept)
- .setStandardization(standardize)
- .setSolver(solver)
- case "binomial" => new LogisticRegression()
- .setRegParam(lambda)
- .setElasticNetParam(alpha)
- .setFitIntercept(formula.hasIntercept)
- .setStandardization(standardize)
- }
- val pipeline = new Pipeline().setStages(Array(formula, estimator))
- pipeline.fit(df)
- }
-
- def fitKMeans(
- df: DataFrame,
- initMode: String,
- maxIter: Double,
- k: Double,
- columns: Array[String]): PipelineModel = {
- val assembler = new VectorAssembler().setInputCols(columns)
- val kMeans = new KMeans()
- .setInitMode(initMode)
- .setMaxIter(maxIter.toInt)
- .setK(k.toInt)
- .setFeaturesCol(assembler.getOutputCol)
- val pipeline = new Pipeline().setStages(Array(assembler, kMeans))
- pipeline.fit(df)
- }
-
- def getModelCoefficients(model: PipelineModel): Array[Double] = {
- model.stages.last match {
- case m: LinearRegressionModel => {
- val coefficientStandardErrorsR = Array(m.summary.coefficientStandardErrors.last) ++
- m.summary.coefficientStandardErrors.dropRight(1)
- val tValuesR = Array(m.summary.tValues.last) ++ m.summary.tValues.dropRight(1)
- val pValuesR = Array(m.summary.pValues.last) ++ m.summary.pValues.dropRight(1)
- if (m.getFitIntercept) {
- Array(m.intercept) ++ m.coefficients.toArray ++ coefficientStandardErrorsR ++
- tValuesR ++ pValuesR
- } else {
- m.coefficients.toArray ++ coefficientStandardErrorsR ++ tValuesR ++ pValuesR
- }
- }
- case m: LogisticRegressionModel => {
- if (m.getFitIntercept) {
- Array(m.intercept) ++ m.coefficients.toArray
- } else {
- m.coefficients.toArray
- }
- }
- case m: KMeansModel =>
- m.clusterCenters.flatMap(_.toArray)
- }
- }
-
- def getModelDevianceResiduals(model: PipelineModel): Array[Double] = {
- model.stages.last match {
- case m: LinearRegressionModel =>
- m.summary.devianceResiduals
- case m: LogisticRegressionModel =>
- throw new UnsupportedOperationException(
- "No deviance residuals available for LogisticRegressionModel")
- }
- }
-
- def getKMeansModelSize(model: PipelineModel): Array[Int] = {
- model.stages.last match {
- case m: KMeansModel => Array(m.getK) ++ m.summary.size
- case other => throw new UnsupportedOperationException(
- s"KMeansModel required but ${other.getClass.getSimpleName} found.")
- }
- }
-
- def getKMeansCluster(model: PipelineModel, method: String): DataFrame = {
- model.stages.last match {
- case m: KMeansModel =>
- if (method == "centers") {
- // Drop the assembled vector for easy-print to R side.
- m.summary.predictions.drop(m.summary.featuresCol)
- } else if (method == "classes") {
- m.summary.cluster
- } else {
- throw new UnsupportedOperationException(
- s"Method (centers or classes) required but $method found.")
- }
- case other => throw new UnsupportedOperationException(
- s"KMeansModel required but ${other.getClass.getSimpleName} found.")
- }
- }
-
- def getModelFeatures(model: PipelineModel): Array[String] = {
- model.stages.last match {
- case m: LinearRegressionModel =>
- val attrs = AttributeGroup.fromStructField(
- m.summary.predictions.schema(m.summary.featuresCol))
- if (m.getFitIntercept) {
- Array("(Intercept)") ++ attrs.attributes.get.map(_.name.get)
- } else {
- attrs.attributes.get.map(_.name.get)
- }
- case m: LogisticRegressionModel =>
- val attrs = AttributeGroup.fromStructField(
- m.summary.predictions.schema(m.summary.featuresCol))
- if (m.getFitIntercept) {
- Array("(Intercept)") ++ attrs.attributes.get.map(_.name.get)
- } else {
- attrs.attributes.get.map(_.name.get)
- }
- case m: KMeansModel =>
- val attrs = AttributeGroup.fromStructField(
- m.summary.predictions.schema(m.summary.featuresCol))
- attrs.attributes.get.map(_.name.get)
- }
- }
-
- def getModelName(model: PipelineModel): String = {
- model.stages.last match {
- case m: LinearRegressionModel =>
- "LinearRegressionModel"
- case m: LogisticRegressionModel =>
- "LogisticRegressionModel"
- case m: KMeansModel =>
- "KMeansModel"
- }
- }
-}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
index 4a3ad662a0..36dce01590 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
@@ -40,7 +40,7 @@ import org.apache.spark.ml.util._
import org.apache.spark.mllib.linalg.CholeskyDecomposition
import org.apache.spark.mllib.optimization.NNLS
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{DoubleType, FloatType, IntegerType, StructType}
import org.apache.spark.storage.StorageLevel
@@ -200,8 +200,8 @@ class ALSModel private[ml] (
@Since("1.3.0")
def setPredictionCol(value: String): this.type = set(predictionCol, value)
- @Since("1.3.0")
- override def transform(dataset: DataFrame): DataFrame = {
+ @Since("2.0.0")
+ override def transform(dataset: Dataset[_]): DataFrame = {
// Register a UDF for DataFrame, and then
// create a new column named map(predictionCol) by running the predict UDF.
val predict = udf { (userFeatures: Seq[Float], itemFeatures: Seq[Float]) =>
@@ -385,8 +385,8 @@ class ALS(@Since("1.4.0") override val uid: String) extends Estimator[ALSModel]
this
}
- @Since("1.3.0")
- override def fit(dataset: DataFrame): ALSModel = {
+ @Since("2.0.0")
+ override def fit(dataset: Dataset[_]): ALSModel = {
import dataset.sqlContext.implicits._
val r = if ($(ratingCol) != "") col($(ratingCol)).cast(FloatType) else lit(1.0f)
val ratings = dataset
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
index ba5708ab8d..89ba6ab5d2 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
@@ -31,8 +31,9 @@ import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.util._
import org.apache.spark.mllib.linalg.{BLAS, Vector, Vectors, VectorUDT}
+import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.{DataFrame, Dataset, Row}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{DoubleType, StructType}
import org.apache.spark.storage.StorageLevel
@@ -103,7 +104,7 @@ private[regression] trait AFTSurvivalRegressionParams extends Params
SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT)
if (fitting) {
SchemaUtils.checkColumnType(schema, $(censorCol), DoubleType)
- SchemaUtils.checkColumnType(schema, $(labelCol), DoubleType)
+ SchemaUtils.checkNumericType(schema, $(labelCol))
}
if (hasQuantilesCol) {
SchemaUtils.appendColumn(schema, $(quantilesCol), new VectorUDT)
@@ -183,24 +184,35 @@ class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: S
* Extract [[featuresCol]], [[labelCol]] and [[censorCol]] from input dataset,
 * and put them in an RDD with strong types.
*/
- protected[ml] def extractAFTPoints(dataset: DataFrame): RDD[AFTPoint] = {
- dataset.select($(featuresCol), $(labelCol), $(censorCol)).rdd.map {
- case Row(features: Vector, label: Double, censor: Double) =>
- AFTPoint(features, label, censor)
- }
+ protected[ml] def extractAFTPoints(dataset: Dataset[_]): RDD[AFTPoint] = {
+ dataset.select(col($(featuresCol)), col($(labelCol)).cast(DoubleType), col($(censorCol)))
+ .rdd.map {
+ case Row(features: Vector, label: Double, censor: Double) =>
+ AFTPoint(features, label, censor)
+ }
}
- @Since("1.6.0")
- override def fit(dataset: DataFrame): AFTSurvivalRegressionModel = {
+ @Since("2.0.0")
+ override def fit(dataset: Dataset[_]): AFTSurvivalRegressionModel = {
validateAndTransformSchema(dataset.schema, fitting = true)
val instances = extractAFTPoints(dataset)
val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE
if (handlePersistence) instances.persist(StorageLevel.MEMORY_AND_DISK)
- val costFun = new AFTCostFun(instances, $(fitIntercept))
+ val featuresSummarizer = {
+ val seqOp = (c: MultivariateOnlineSummarizer, v: AFTPoint) => c.add(v.features)
+ val combOp = (c1: MultivariateOnlineSummarizer, c2: MultivariateOnlineSummarizer) => {
+ c1.merge(c2)
+ }
+ instances.treeAggregate(new MultivariateOnlineSummarizer)(seqOp, combOp)
+ }
+
+ val featuresStd = featuresSummarizer.variance.toArray.map(math.sqrt)
+
+ val costFun = new AFTCostFun(instances, $(fitIntercept), featuresStd)
val optimizer = new BreezeLBFGS[BDV[Double]]($(maxIter), 10, $(tol))
- val numFeatures = dataset.select($(featuresCol)).take(1)(0).getAs[Vector](0).size
+ val numFeatures = featuresStd.size
/*
The parameters vector has three parts:
the first element: Double, log(sigma), the log of scale parameter
@@ -229,7 +241,13 @@ class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: S
if (handlePersistence) instances.unpersist()
- val coefficients = Vectors.dense(parameters.slice(2, parameters.length))
+ val rawCoefficients = parameters.slice(2, parameters.length)
+ var i = 0
+ while (i < numFeatures) {
+ rawCoefficients(i) *= { if (featuresStd(i) != 0.0) 1.0 / featuresStd(i) else 0.0 }
+ i += 1
+ }
+ val coefficients = Vectors.dense(rawCoefficients)
val intercept = parameters(1)
val scale = math.exp(parameters(0))
val model = new AFTSurvivalRegressionModel(uid, coefficients, intercept, scale)
@@ -298,8 +316,8 @@ class AFTSurvivalRegressionModel private[ml] (
math.exp(BLAS.dot(coefficients, features) + intercept)
}
- @Since("1.6.0")
- override def transform(dataset: DataFrame): DataFrame = {
+ @Since("2.0.0")
+ override def transform(dataset: Dataset[_]): DataFrame = {
transformSchema(dataset.schema)
val predictUDF = udf { features: Vector => predict(features) }
val predictQuantilesUDF = udf { features: Vector => predictQuantiles(features)}
@@ -433,29 +451,36 @@ object AFTSurvivalRegressionModel extends MLReadable[AFTSurvivalRegressionModel]
 * @param parameters including three parts: the log of the scale parameter, the intercept and
* regression coefficients corresponding to the features.
* @param fitIntercept Whether to fit an intercept term.
+ * @param featuresStd The standard deviation values of the features.
*/
-private class AFTAggregator(parameters: BDV[Double], fitIntercept: Boolean)
- extends Serializable {
+private class AFTAggregator(
+ parameters: BDV[Double],
+ fitIntercept: Boolean,
+ featuresStd: Array[Double]) extends Serializable {
// the regression coefficients to the covariates
private val coefficients = parameters.slice(2, parameters.length)
- private val intercept = parameters.valueAt(1)
+ private val intercept = parameters(1)
// sigma is the scale parameter of the AFT model
private val sigma = math.exp(parameters(0))
private var totalCnt: Long = 0L
private var lossSum = 0.0
- private var gradientCoefficientSum = BDV.zeros[Double](coefficients.length)
- private var gradientInterceptSum = 0.0
- private var gradientLogSigmaSum = 0.0
+ // Here we optimize loss function over log(sigma), intercept and coefficients
+ private val gradientSumArray = Array.ofDim[Double](parameters.length)
def count: Long = totalCnt
+ def loss: Double = {
+ require(totalCnt > 0.0, s"The number of instances should be " +
+ s"greater than 0.0, but got $totalCnt.")
+ lossSum / totalCnt
+ }
+ def gradient: BDV[Double] = {
+ require(totalCnt > 0.0, s"The number of instances should be " +
+ s"greater than 0.0, but got $totalCnt.")
+ new BDV(gradientSumArray.map(_ / totalCnt.toDouble))
+ }
- def loss: Double = if (totalCnt == 0) 1.0 else lossSum / totalCnt
-
- // Here we optimize loss function over coefficients, intercept and log(sigma)
- def gradient: BDV[Double] = BDV.vertcat(BDV(Array(gradientLogSigmaSum / totalCnt.toDouble)),
- BDV(Array(gradientInterceptSum/totalCnt.toDouble)), gradientCoefficientSum/totalCnt.toDouble)
/**
 * Add a new training instance to this AFTAggregator, and update the loss and gradient
@@ -465,25 +490,32 @@ private class AFTAggregator(parameters: BDV[Double], fitIntercept: Boolean)
* @return This AFTAggregator object.
*/
def add(data: AFTPoint): this.type = {
-
- val interceptFlag = if (fitIntercept) 1.0 else 0.0
-
- val xi = data.features.toBreeze
+ val xi = data.features
val ti = data.label
val delta = data.censor
- val epsilon = (math.log(ti) - coefficients.dot(xi) - intercept * interceptFlag ) / sigma
- lossSum += math.log(sigma) * delta
- lossSum += (math.exp(epsilon) - delta * epsilon)
+ val margin = {
+ var sum = 0.0
+ xi.foreachActive { (index, value) =>
+ if (featuresStd(index) != 0.0 && value != 0.0) {
+ sum += coefficients(index) * (value / featuresStd(index))
+ }
+ }
+ sum + intercept
+ }
+ val epsilon = (math.log(ti) - margin) / sigma
+
+ lossSum += delta * math.log(sigma) - delta * epsilon + math.exp(epsilon)
- // Sanity check (should never occur):
- assert(!lossSum.isInfinity,
- s"AFTAggregator loss sum is infinity. Error for unknown reason.")
+ val multiplier = (delta - math.exp(epsilon)) / sigma
- val deltaMinusExpEps = delta - math.exp(epsilon)
- gradientCoefficientSum += xi * deltaMinusExpEps / sigma
- gradientInterceptSum += interceptFlag * deltaMinusExpEps / sigma
- gradientLogSigmaSum += delta + deltaMinusExpEps * epsilon
+ gradientSumArray(0) += delta + multiplier * sigma * epsilon
+ gradientSumArray(1) += { if (fitIntercept) multiplier else 0.0 }
+ xi.foreachActive { (index, value) =>
+ if (featuresStd(index) != 0.0 && value != 0.0) {
+ gradientSumArray(index + 2) += multiplier * (value / featuresStd(index))
+ }
+ }
totalCnt += 1
this
@@ -502,9 +534,12 @@ private class AFTAggregator(parameters: BDV[Double], fitIntercept: Boolean)
totalCnt += other.totalCnt
lossSum += other.lossSum
- gradientCoefficientSum += other.gradientCoefficientSum
- gradientInterceptSum += other.gradientInterceptSum
- gradientLogSigmaSum += other.gradientLogSigmaSum
+ var i = 0
+ val len = this.gradientSumArray.length
+ while (i < len) {
+ this.gradientSumArray(i) += other.gradientSumArray(i)
+ i += 1
+ }
}
this
}
@@ -515,12 +550,15 @@ private class AFTAggregator(parameters: BDV[Double], fitIntercept: Boolean)
* It returns the loss and gradient at a particular point (parameters).
* It's used in Breeze's convex optimization routines.
*/
-private class AFTCostFun(data: RDD[AFTPoint], fitIntercept: Boolean)
- extends DiffFunction[BDV[Double]] {
+private class AFTCostFun(
+ data: RDD[AFTPoint],
+ fitIntercept: Boolean,
+ featuresStd: Array[Double]) extends DiffFunction[BDV[Double]] {
override def calculate(parameters: BDV[Double]): (Double, BDV[Double]) = {
- val aftAggregator = data.treeAggregate(new AFTAggregator(parameters, fitIntercept))(
+ val aftAggregator = data.treeAggregate(
+ new AFTAggregator(parameters, fitIntercept, featuresStd))(
seqOp = (c, v) => (c, v) match {
case (aggregator, instance) => aggregator.add(instance)
},
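A usage sketch of the estimator whose internals change above (`training` is an assumed DataFrame with "label", "censor" and "features" columns). Although the optimizer now works on standardized features, the returned coefficients are scaled back to the original feature space:

    import org.apache.spark.ml.regression.AFTSurvivalRegression

    val aft = new AFTSurvivalRegression()
      .setQuantileProbabilities(Array(0.3, 0.6))
      .setQuantilesCol("quantiles")

    val model = aft.fit(training)
    // Coefficients refer to the original (unstandardized) features because fit()
    // rescales them by 1 / featuresStd after optimization.
    println(s"coefficients=${model.coefficients} intercept=${model.intercept} scale=${model.scale}")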
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
index 50ac96eb5e..c04c416aaf 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
@@ -33,7 +33,7 @@ import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy}
import org.apache.spark.mllib.tree.model.{DecisionTreeModel => OldDecisionTreeModel}
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.functions._
@@ -83,7 +83,7 @@ final class DecisionTreeRegressor @Since("1.4.0") (@Since("1.4.0") override val
/** @group setParam */
def setVarianceCol(value: String): this.type = set(varianceCol, value)
- override protected def train(dataset: DataFrame): DecisionTreeRegressionModel = {
+ override protected def train(dataset: Dataset[_]): DecisionTreeRegressionModel = {
val categoricalFeatures: Map[Int, Int] =
MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset)
@@ -158,15 +158,16 @@ final class DecisionTreeRegressionModel private[ml] (
rootNode.predictImpl(features).impurityStats.calculate()
}
- override def transform(dataset: DataFrame): DataFrame = {
+ @Since("2.0.0")
+ override def transform(dataset: Dataset[_]): DataFrame = {
transformSchema(dataset.schema, logging = true)
transformImpl(dataset)
}
- override protected def transformImpl(dataset: DataFrame): DataFrame = {
+ override protected def transformImpl(dataset: Dataset[_]): DataFrame = {
val predictUDF = udf { (features: Vector) => predict(features) }
val predictVarianceUDF = udf { (features: Vector) => predictVariance(features) }
- var output = dataset
+ var output = dataset.toDF
if ($(predictionCol).nonEmpty) {
output = output.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
}
@@ -203,9 +204,9 @@ final class DecisionTreeRegressionModel private[ml] (
* to determine feature importance instead.
*/
@Since("2.0.0")
- lazy val featureImportances: Vector = RandomForest.featureImportances(this, numFeatures)
+ lazy val featureImportances: Vector = TreeEnsembleModel.featureImportances(this, numFeatures)
- /** Convert to spark.mllib DecisionTreeModel (losing some infomation) */
+ /** Convert to spark.mllib DecisionTreeModel (losing some information) */
override private[spark] def toOld: OldDecisionTreeModel = {
new OldDecisionTreeModel(rootNode.toOld(1), OldAlgo.Regression)
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
index da5b77e8fa..741724d7a1 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
@@ -18,23 +18,23 @@
package org.apache.spark.ml.regression
import com.github.fommil.netlib.BLAS.{getInstance => blas}
+import org.json4s.{DefaultFormats, JObject}
+import org.json4s.JsonDSL._
import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.internal.Logging
import org.apache.spark.ml.{PredictionModel, Predictor}
-import org.apache.spark.ml.param.{Param, ParamMap}
-import org.apache.spark.ml.tree.{DecisionTreeModel, GBTParams, TreeEnsembleModel,
- TreeRegressorParams}
+import org.apache.spark.ml.param.ParamMap
+import org.apache.spark.ml.tree._
import org.apache.spark.ml.tree.impl.GradientBoostedTrees
-import org.apache.spark.ml.util.{Identifiable, MetadataUtils}
+import org.apache.spark.ml.util._
+import org.apache.spark.ml.util.DefaultParamsReader.Metadata
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
-import org.apache.spark.mllib.tree.loss.{AbsoluteError => OldAbsoluteError, Loss => OldLoss,
- SquaredError => OldSquaredError}
import org.apache.spark.mllib.tree.model.{GradientBoostedTreesModel => OldGBTModel}
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.functions._
/**
@@ -42,12 +42,24 @@ import org.apache.spark.sql.functions._
* [[http://en.wikipedia.org/wiki/Gradient_boosting Gradient-Boosted Trees (GBTs)]]
* learning algorithm for regression.
* It supports both continuous and categorical features.
+ *
+ * The implementation is based upon: J.H. Friedman. "Stochastic Gradient Boosting." 1999.
+ *
+ * Notes on Gradient Boosting vs. TreeBoost:
+ * - This implementation is for Stochastic Gradient Boosting, not for TreeBoost.
+ * - Both algorithms learn tree ensembles by minimizing loss functions.
+ * - TreeBoost (Friedman, 1999) additionally modifies the outputs at tree leaf nodes
+ * based on the loss function, whereas the original gradient boosting method does not.
+ * - When the loss is SquaredError, these methods give the same result, but they could differ
+ * for other loss functions.
+ * - We expect to implement TreeBoost in the future:
+ * [https://issues.apache.org/jira/browse/SPARK-4240]
*/
@Since("1.4.0")
@Experimental
final class GBTRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String)
extends Predictor[Vector, GBTRegressor, GBTRegressionModel]
- with GBTParams with TreeRegressorParams with Logging {
+ with GBTRegressorParams with DefaultParamsWritable with Logging {
@Since("1.4.0")
def this() = this(Identifiable.randomUID("gbtr"))
@@ -101,42 +113,13 @@ final class GBTRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: Stri
@Since("1.4.0")
override def setStepSize(value: Double): this.type = super.setStepSize(value)
- // Parameters for GBTRegressor:
-
- /**
- * Loss function which GBT tries to minimize. (case-insensitive)
- * Supported: "squared" (L2) and "absolute" (L1)
- * (default = squared)
- * @group param
- */
- @Since("1.4.0")
- val lossType: Param[String] = new Param[String](this, "lossType", "Loss function which GBT" +
- " tries to minimize (case-insensitive). Supported options:" +
- s" ${GBTRegressor.supportedLossTypes.mkString(", ")}",
- (value: String) => GBTRegressor.supportedLossTypes.contains(value.toLowerCase))
-
- setDefault(lossType -> "squared")
+ // Parameters from GBTRegressorParams:
/** @group setParam */
@Since("1.4.0")
def setLossType(value: String): this.type = set(lossType, value)
- /** @group getParam */
- @Since("1.4.0")
- def getLossType: String = $(lossType).toLowerCase
-
- /** (private[ml]) Convert new loss to old loss. */
- override private[ml] def getOldLossType: OldLoss = {
- getLossType match {
- case "squared" => OldSquaredError
- case "absolute" => OldAbsoluteError
- case _ =>
- // Should never happen because of check in setter method.
- throw new RuntimeException(s"GBTRegressorParams was given bad loss type: $getLossType")
- }
- }
-
- override protected def train(dataset: DataFrame): GBTRegressionModel = {
+ override protected def train(dataset: Dataset[_]): GBTRegressionModel = {
val categoricalFeatures: Map[Int, Int] =
MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset)
@@ -153,11 +136,14 @@ final class GBTRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: Stri
@Since("1.4.0")
@Experimental
-object GBTRegressor {
- // The losses below should be lowercase.
+object GBTRegressor extends DefaultParamsReadable[GBTRegressor] {
+
/** Accessor for supported loss settings: squared (L2), absolute (L1) */
@Since("1.4.0")
- final val supportedLossTypes: Array[String] = Array("squared", "absolute").map(_.toLowerCase)
+ final val supportedLossTypes: Array[String] = GBTRegressorParams.supportedLossTypes
+
+ @Since("2.0.0")
+ override def load(path: String): GBTRegressor = super.load(path)
}
/**
@@ -177,9 +163,10 @@ final class GBTRegressionModel private[ml](
private val _treeWeights: Array[Double],
override val numFeatures: Int)
extends PredictionModel[Vector, GBTRegressionModel]
- with TreeEnsembleModel with Serializable {
+ with GBTRegressorParams with TreeEnsembleModel[DecisionTreeRegressionModel]
+ with MLWritable with Serializable {
- require(numTrees > 0, "GBTRegressionModel requires at least 1 tree.")
+ require(_trees.nonEmpty, "GBTRegressionModel requires at least 1 tree.")
require(_trees.length == _treeWeights.length, "GBTRegressionModel given trees, treeWeights of" +
s" non-matching lengths (${_trees.length}, ${_treeWeights.length}, respectively).")
@@ -193,12 +180,12 @@ final class GBTRegressionModel private[ml](
this(uid, _trees, _treeWeights, -1)
@Since("1.4.0")
- override def trees: Array[DecisionTreeModel] = _trees.asInstanceOf[Array[DecisionTreeModel]]
+ override def trees: Array[DecisionTreeRegressionModel] = _trees
@Since("1.4.0")
override def treeWeights: Array[Double] = _treeWeights
- override protected def transformImpl(dataset: DataFrame): DataFrame = {
+ override protected def transformImpl(dataset: Dataset[_]): DataFrame = {
val bcastModel = dataset.sqlContext.sparkContext.broadcast(this)
val predictUDF = udf { (features: Any) =>
bcastModel.value.predict(features.asInstanceOf[Vector])
@@ -213,6 +200,9 @@ final class GBTRegressionModel private[ml](
blas.ddot(numTrees, treePredictions, 1, _treeWeights, 1)
}
+ /** Number of trees in ensemble */
+ val numTrees: Int = trees.length
+
@Since("1.4.0")
override def copy(extra: ParamMap): GBTRegressionModel = {
copyValues(new GBTRegressionModel(uid, _trees, _treeWeights, numFeatures),
@@ -224,16 +214,81 @@ final class GBTRegressionModel private[ml](
s"GBTRegressionModel (uid=$uid) with $numTrees trees"
}
+ /**
+ * Estimate of the importance of each feature.
+ *
+   * Each feature's importance is the average of its importance across all trees in the ensemble.
+ * The importance vector is normalized to sum to 1. This method is suggested by Hastie et al.
+ * (Hastie, Tibshirani, Friedman. "The Elements of Statistical Learning, 2nd Edition." 2001.)
+ * and follows the implementation from scikit-learn.
+ *
+ * @see [[DecisionTreeRegressionModel.featureImportances]]
+ */
+ @Since("2.0.0")
+ lazy val featureImportances: Vector = TreeEnsembleModel.featureImportances(trees, numFeatures)
+
/** (private[ml]) Convert to a model in the old API */
private[ml] def toOld: OldGBTModel = {
new OldGBTModel(OldAlgo.Regression, _trees.map(_.toOld), _treeWeights)
}
+
+ @Since("2.0.0")
+ override def write: MLWriter = new GBTRegressionModel.GBTRegressionModelWriter(this)
}
-private[ml] object GBTRegressionModel {
+@Since("2.0.0")
+object GBTRegressionModel extends MLReadable[GBTRegressionModel] {
+
+ @Since("2.0.0")
+ override def read: MLReader[GBTRegressionModel] = new GBTRegressionModelReader
+
+ @Since("2.0.0")
+ override def load(path: String): GBTRegressionModel = super.load(path)
+
+ private[GBTRegressionModel]
+ class GBTRegressionModelWriter(instance: GBTRegressionModel) extends MLWriter {
+
+ override protected def saveImpl(path: String): Unit = {
+ val extraMetadata: JObject = Map(
+ "numFeatures" -> instance.numFeatures,
+ "numTrees" -> instance.getNumTrees)
+ EnsembleModelReadWrite.saveImpl(instance, path, sqlContext, extraMetadata)
+ }
+ }
+
+ private class GBTRegressionModelReader extends MLReader[GBTRegressionModel] {
+
+ /** Checked against metadata when loading model */
+ private val className = classOf[GBTRegressionModel].getName
+ private val treeClassName = classOf[DecisionTreeRegressionModel].getName
+
+ override def load(path: String): GBTRegressionModel = {
+ implicit val format = DefaultFormats
+ val (metadata: Metadata, treesData: Array[(Metadata, Node)], treeWeights: Array[Double]) =
+ EnsembleModelReadWrite.loadImpl(path, sqlContext, className, treeClassName)
+
+ val numFeatures = (metadata.metadata \ "numFeatures").extract[Int]
+ val numTrees = (metadata.metadata \ "numTrees").extract[Int]
+
+ val trees: Array[DecisionTreeRegressionModel] = treesData.map {
+ case (treeMetadata, root) =>
+ val tree =
+ new DecisionTreeRegressionModel(treeMetadata.uid, root, numFeatures)
+ DefaultParamsReader.getAndSetParams(tree, treeMetadata)
+ tree
+ }
+
+ require(numTrees == trees.length, s"GBTRegressionModel.load expected $numTrees" +
+ s" trees based on metadata but found ${trees.length} trees.")
+
+ val model = new GBTRegressionModel(metadata.uid, trees, treeWeights, numFeatures)
+ DefaultParamsReader.getAndSetParams(model, metadata)
+ model
+ }
+ }
- /** (private[ml]) Convert a model from the old API */
- def fromOld(
+ /** Convert a model from the old API */
+ private[ml] def fromOld(
oldModel: OldGBTModel,
parent: GBTRegressor,
categoricalFeatures: Map[Int, Int],
@@ -245,6 +300,6 @@ private[ml] object GBTRegressionModel {
DecisionTreeRegressionModel.fromOld(tree, null, categoricalFeatures)
}
val uid = if (parent != null) parent.uid else Identifiable.randomUID("gbtr")
- new GBTRegressionModel(parent.uid, newTrees, oldModel.treeWeights, numFeatures)
+ new GBTRegressionModel(uid, newTrees, oldModel.treeWeights, numFeatures)
}
}
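
The GBT changes above fold ML persistence (DefaultParamsReadable / MLWritable) and ensemble feature importances into the 2.0 API. A minimal spark-shell sketch of how they combine, assuming a DataFrame named `training` with the default "label"/"features" columns and a scratch path "/tmp/gbtr-model" (both placeholders):

  import org.apache.spark.ml.regression.{GBTRegressionModel, GBTRegressor}

  // `training` is an assumed DataFrame with "label" (Double) and "features" (Vector) columns.
  val gbt = new GBTRegressor()
    .setMaxIter(10)
    .setLossType("squared")              // must be one of GBTRegressor.supportedLossTypes
  val model = gbt.fit(training)

  // Per-feature importances, averaged over the ensemble and normalized to sum to 1.
  println(model.featureImportances)

  // Save/load round trip through the new writer and reader.
  model.write.overwrite().save("/tmp/gbtr-model")
  val restored = GBTRegressionModel.load("/tmp/gbtr-model")
  assert(restored.getNumTrees == model.getNumTrees)

Since the reader rebuilds every tree and re-applies the stored params, the restored model should report the same tree count and importances as the original.
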
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
index 0e71e8d8e1..e92a3e7fa1 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
@@ -31,9 +31,9 @@ import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.util._
import org.apache.spark.mllib.linalg.{BLAS, Vector}
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.{DataFrame, Dataset, Row}
import org.apache.spark.sql.functions._
-import org.apache.spark.sql.types.{DataType, StructType}
+import org.apache.spark.sql.types.{DataType, DoubleType, StructType}
/**
* Params for Generalized Linear Regression.
@@ -47,6 +47,7 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam
* to be used in the model.
* Supported options: "gaussian", "binomial", "poisson" and "gamma".
* Default is "gaussian".
+ *
* @group param
*/
@Since("2.0.0")
@@ -63,6 +64,7 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam
* Param for the name of link function which provides the relationship
* between the linear predictor and the mean of the distribution function.
* Supported options: "identity", "log", "inverse", "logit", "probit", "cloglog" and "sqrt".
+ *
* @group param
*/
@Since("2.0.0")
@@ -163,7 +165,11 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val
setDefault(tol -> 1E-6)
/**
- * Sets the regularization parameter.
+ * Sets the regularization parameter for L2 regularization.
+ * The regularization term is
+ * {{{
+ * 0.5 * regParam * L2norm(coefficients)^2
+ * }}}
* Default is 0.0.
* @group setParam
*/
@@ -190,7 +196,7 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val
def setSolver(value: String): this.type = set(solver, value)
setDefault(solver -> "irls")
- override protected def train(dataset: DataFrame): GeneralizedLinearRegressionModel = {
+ override protected def train(dataset: Dataset[_]): GeneralizedLinearRegressionModel = {
val familyObj = Family.fromName($(family))
val linkObj = if (isDefined(link)) {
Link.fromName($(link))
@@ -210,9 +216,10 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val
}
val w = if ($(weightCol).isEmpty) lit(1.0) else col($(weightCol))
- val instances: RDD[Instance] = dataset.select(col($(labelCol)), w, col($(featuresCol))).rdd
- .map { case Row(label: Double, weight: Double, features: Vector) =>
- Instance(label, weight, features)
+ val instances: RDD[Instance] =
+ dataset.select(col($(labelCol)).cast(DoubleType), w, col($(featuresCol))).rdd.map {
+ case Row(label: Double, weight: Double, features: Vector) =>
+ Instance(label, weight, features)
}
if (familyObj == Gaussian && linkObj == Identity) {
@@ -230,7 +237,8 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val
predictionColName,
model,
wlsModel.diagInvAtWA.toArray,
- 1)
+ 1,
+ getSolver)
return model.setSummary(trainingSummary)
}
@@ -250,7 +258,8 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val
predictionColName,
model,
irlsModel.diagInvAtWA.toArray,
- irlsModel.numIterations)
+ irlsModel.numIterations,
+ getSolver)
model.setSummary(trainingSummary)
}
@@ -698,7 +707,7 @@ class GeneralizedLinearRegressionModel private[ml] (
: (GeneralizedLinearRegressionModel, String) = {
$(predictionCol) match {
case "" =>
- val predictionColName = "prediction_" + java.util.UUID.randomUUID.toString()
+ val predictionColName = "prediction_" + java.util.UUID.randomUUID.toString
(copy(ParamMap.empty).setPredictionCol(predictionColName), predictionColName)
case p => (this, p)
}
@@ -769,11 +778,12 @@ object GeneralizedLinearRegressionModel extends MLReadable[GeneralizedLinearRegr
* :: Experimental ::
 * Summarizing Generalized Linear Regression fits.
*
- * @param predictions predictions outputted by the model's `transform` method
+ * @param predictions predictions output by the model's `transform` method
* @param predictionCol field in "predictions" which gives the prediction value of each instance
* @param model the model that should be summarized
* @param diagInvAtWA diagonal of matrix (A^T * W * A)^-1 in the last iteration
* @param numIterations number of iterations
+ * @param solver the solver algorithm used for model training
*/
@Since("2.0.0")
@Experimental
@@ -782,7 +792,8 @@ class GeneralizedLinearRegressionSummary private[regression] (
@Since("2.0.0") val predictionCol: String,
@Since("2.0.0") val model: GeneralizedLinearRegressionModel,
private val diagInvAtWA: Array[Double],
- @Since("2.0.0") val numIterations: Int) extends Serializable {
+ @Since("2.0.0") val numIterations: Int,
+ @Since("2.0.0") val solver: String) extends Serializable {
import GeneralizedLinearRegression._
@@ -930,6 +941,9 @@ class GeneralizedLinearRegressionSummary private[regression] (
/**
* Standard error of estimated coefficients and intercept.
+ *
+ * If [[GeneralizedLinearRegression.fitIntercept]] is set to true,
+ * then the last element returned corresponds to the intercept.
*/
@Since("2.0.0")
lazy val coefficientStandardErrors: Array[Double] = {
@@ -938,6 +952,9 @@ class GeneralizedLinearRegressionSummary private[regression] (
/**
* T-statistic of estimated coefficients and intercept.
+ *
+ * If [[GeneralizedLinearRegression.fitIntercept]] is set to true,
+ * then the last element returned corresponds to the intercept.
*/
@Since("2.0.0")
lazy val tValues: Array[Double] = {
@@ -951,6 +968,9 @@ class GeneralizedLinearRegressionSummary private[regression] (
/**
* Two-sided p-value of estimated coefficients and intercept.
+ *
+ * If [[GeneralizedLinearRegression.fitIntercept]] is set to true,
+ * then the last element returned corresponds to the intercept.
*/
@Since("2.0.0")
lazy val pValues: Array[Double] = {
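
Taken together, the GLR changes cast the label to Double, document the L2 term (0.5 * regParam * ||coefficients||^2), and record the solver in the summary. A short sketch under those assumptions, with `df` standing in for a training DataFrame and the summary accessor name assumed to match the other GLMs:

  import org.apache.spark.ml.regression.GeneralizedLinearRegression

  // `df` is an assumed DataFrame with a numeric "label" column (cast to Double internally
  // after this change) and a Vector "features" column.
  val glr = new GeneralizedLinearRegression()
    .setFamily("poisson")
    .setLink("log")
    .setRegParam(0.1)                    // L2 penalty: 0.5 * regParam * ||coefficients||^2
    .setMaxIter(25)
  val model = glr.fit(df)

  // The summary now carries the solver that produced the fit ("irls" here, or the
  // weighted-least-squares shortcut taken for gaussian/identity). Accessor name assumed.
  val summary = model.summary
  println(s"solver=${summary.solver}, iterations=${summary.numIterations}")

  // With fitIntercept set to true, the last element of each array is the intercept term.
  println(summary.coefficientStandardErrors.mkString(", "))
  println(summary.pValues.mkString(", "))
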
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala
index fb733f9a34..7a78ecbdf1 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala
@@ -30,7 +30,7 @@ import org.apache.spark.mllib.linalg.{Vector, Vectors, VectorUDT}
import org.apache.spark.mllib.regression.{IsotonicRegression => MLlibIsotonicRegression}
import org.apache.spark.mllib.regression.{IsotonicRegressionModel => MLlibIsotonicRegressionModel}
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.{DataFrame, Dataset, Row}
import org.apache.spark.sql.functions.{col, lit, udf}
import org.apache.spark.sql.types.{DoubleType, StructType}
import org.apache.spark.storage.StorageLevel
@@ -77,7 +77,7 @@ private[regression] trait IsotonicRegressionBase extends Params with HasFeatures
* Extracts (label, feature, weight) from input dataset.
*/
protected[ml] def extractWeightedLabeledPoints(
- dataset: DataFrame): RDD[(Double, Double, Double)] = {
+ dataset: Dataset[_]): RDD[(Double, Double, Double)] = {
val f = if (dataset.schema($(featuresCol)).dataType.isInstanceOf[VectorUDT]) {
val idx = $(featureIndex)
val extract = udf { v: Vector => v(idx) }
@@ -90,7 +90,7 @@ private[regression] trait IsotonicRegressionBase extends Params with HasFeatures
} else {
lit(1.0)
}
- dataset.select(col($(labelCol)), f, w).rdd.map {
+ dataset.select(col($(labelCol)).cast(DoubleType), f, w).rdd.map {
case Row(label: Double, feature: Double, weight: Double) =>
(label, feature, weight)
}
@@ -106,7 +106,7 @@ private[regression] trait IsotonicRegressionBase extends Params with HasFeatures
schema: StructType,
fitting: Boolean): StructType = {
if (fitting) {
- SchemaUtils.checkColumnType(schema, $(labelCol), DoubleType)
+ SchemaUtils.checkNumericType(schema, $(labelCol))
if (hasWeightCol) {
SchemaUtils.checkColumnType(schema, $(weightCol), DoubleType)
} else {
@@ -164,8 +164,8 @@ class IsotonicRegression @Since("1.5.0") (@Since("1.5.0") override val uid: Stri
@Since("1.5.0")
override def copy(extra: ParamMap): IsotonicRegression = defaultCopy(extra)
- @Since("1.5.0")
- override def fit(dataset: DataFrame): IsotonicRegressionModel = {
+ @Since("2.0.0")
+ override def fit(dataset: Dataset[_]): IsotonicRegressionModel = {
validateAndTransformSchema(dataset.schema, fitting = true)
// Extract columns from data. If dataset is persisted, do not persist oldDataset.
val instances = extractWeightedLabeledPoints(dataset)
@@ -236,8 +236,8 @@ class IsotonicRegressionModel private[ml] (
copyValues(new IsotonicRegressionModel(uid, oldModel), extra).setParent(parent)
}
- @Since("1.5.0")
- override def transform(dataset: DataFrame): DataFrame = {
+ @Since("2.0.0")
+ override def transform(dataset: Dataset[_]): DataFrame = {
val predict = dataset.schema($(featuresCol)).dataType match {
case DoubleType =>
udf { feature: Double => oldModel.predict(feature) }
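
With fit() and transform() now taking Dataset[_] and the label check relaxed to any numeric type, an integer label column works directly. A minimal spark-shell sketch, assuming `sqlContext` is in scope:

  import org.apache.spark.ml.regression.IsotonicRegression
  import sqlContext.implicits._          // spark-shell assumed; sqlContext is predefined there

  // Int labels are accepted and cast to Double internally after this change.
  val data = Seq((1, 0.0), (3, 1.0), (2, 2.0), (5, 3.0)).toDF("label", "features")

  val ir = new IsotonicRegression().setIsotonic(true)
  val model = ir.fit(data)               // fit() accepts any Dataset[_]
  model.transform(data).show()           // adds a "prediction" column
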
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
index b81c588e44..71e02730c7 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
@@ -38,8 +38,9 @@ import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.linalg.BLAS._
import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.{DataFrame, Dataset, Row}
import org.apache.spark.sql.functions._
+import org.apache.spark.sql.types.DoubleType
import org.apache.spark.storage.StorageLevel
/**
@@ -57,7 +58,7 @@ private[regression] trait LinearRegressionParams extends PredictorParams
* The specific squared error loss function used is:
* L = 1/2n ||A coefficients - y||^2^
*
- * This support multiple types of regularization:
+ * This supports multiple types of regularization:
* - none (a.k.a. ordinary least squares)
* - L2 (ridge regression)
* - L1 (Lasso)
@@ -157,7 +158,7 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
def setSolver(value: String): this.type = set(solver, value)
setDefault(solver -> "auto")
- override protected def train(dataset: DataFrame): LinearRegressionModel = {
+ override protected def train(dataset: Dataset[_]): LinearRegressionModel = {
// Extract the number of features before deciding optimization solver.
val numFeatures = dataset.select(col($(featuresCol))).limit(1).rdd.map {
case Row(features: Vector) => features.size
@@ -171,7 +172,7 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
    // For low dimensional data, WeightedLeastSquares is more efficient since the
// training algorithm only requires one pass through the data. (SPARK-10668)
val instances: RDD[Instance] = dataset.select(
- col($(labelCol)), w, col($(featuresCol))).rdd.map {
+ col($(labelCol)).cast(DoubleType), w, col($(featuresCol))).rdd.map {
case Row(label: Double, weight: Double, features: Vector) =>
Instance(label, weight, features)
}
@@ -189,9 +190,9 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
summaryModel.transform(dataset),
predictionColName,
$(labelCol),
+ $(featuresCol),
summaryModel,
model.diagInvAtWA.toArray,
- $(featuresCol),
Array(0D))
return lrModel.setSummary(trainingSummary)
@@ -248,9 +249,9 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
summaryModel.transform(dataset),
predictionColName,
$(labelCol),
+ $(featuresCol),
model,
Array(0D),
- $(featuresCol),
Array(0D))
return copyValues(model.setSummary(trainingSummary))
} else {
@@ -355,9 +356,9 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
summaryModel.transform(dataset),
predictionColName,
$(labelCol),
+ $(featuresCol),
model,
Array(0D),
- $(featuresCol),
objectiveHistory)
model.setSummary(trainingSummary)
}
@@ -412,15 +413,15 @@ class LinearRegressionModel private[ml] (
def hasSummary: Boolean = trainingSummary.isDefined
/**
- * Evaluates the model on a testset.
+ * Evaluates the model on a test dataset.
* @param dataset Test dataset to evaluate model on.
*/
- // TODO: decide on a good name before exposing to public API
- private[regression] def evaluate(dataset: DataFrame): LinearRegressionSummary = {
+ @Since("2.0.0")
+ def evaluate(dataset: Dataset[_]): LinearRegressionSummary = {
// Handle possible missing or invalid prediction columns
val (summaryModel, predictionColName) = findSummaryModelAndPredictionCol()
new LinearRegressionSummary(summaryModel.transform(dataset), predictionColName,
- $(labelCol), this, Array(0D))
+ $(labelCol), $(featuresCol), summaryModel, Array(0D))
}
/**
@@ -431,7 +432,7 @@ class LinearRegressionModel private[ml] (
private[regression] def findSummaryModelAndPredictionCol(): (LinearRegressionModel, String) = {
$(predictionCol) match {
case "" =>
- val predictionColName = "prediction_" + java.util.UUID.randomUUID.toString()
+ val predictionColName = "prediction_" + java.util.UUID.randomUUID.toString
(copy(ParamMap.empty).setPredictionCol(predictionColName), predictionColName)
case p => (this, p)
}
@@ -510,9 +511,9 @@ object LinearRegressionModel extends MLReadable[LinearRegressionModel] {
/**
* :: Experimental ::
* Linear regression training results. Currently, the training summary ignores the
- * training coefficients except for the objective trace.
+ * training weights except for the objective trace.
*
- * @param predictions predictions outputted by the model's `transform` method.
+ * @param predictions predictions output by the model's `transform` method.
* @param objectiveHistory objective function (scaled loss + regularization) at each iteration.
*/
@Since("1.5.0")
@@ -521,13 +522,24 @@ class LinearRegressionTrainingSummary private[regression] (
predictions: DataFrame,
predictionCol: String,
labelCol: String,
+ featuresCol: String,
model: LinearRegressionModel,
diagInvAtWA: Array[Double],
- val featuresCol: String,
val objectiveHistory: Array[Double])
- extends LinearRegressionSummary(predictions, predictionCol, labelCol, model, diagInvAtWA) {
+ extends LinearRegressionSummary(
+ predictions,
+ predictionCol,
+ labelCol,
+ featuresCol,
+ model,
+ diagInvAtWA) {
- /** Number of training iterations until termination */
+ /**
+ * Number of training iterations until termination
+ *
+ * This value is only available when using the "l-bfgs" solver.
+ * @see [[LinearRegression.solver]]
+ */
@Since("1.5.0")
val totalIterations = objectiveHistory.length
@@ -537,7 +549,11 @@ class LinearRegressionTrainingSummary private[regression] (
* :: Experimental ::
* Linear regression results evaluated on a dataset.
*
- * @param predictions predictions outputted by the model's `transform` method.
+ * @param predictions predictions output by the model's `transform` method.
+ * @param predictionCol Field in "predictions" which gives the predicted value of the label at
+ * each instance.
+ * @param labelCol Field in "predictions" which gives the true label of each instance.
+ * @param featuresCol Field in "predictions" which gives the features of each instance as a vector.
*/
@Since("1.5.0")
@Experimental
@@ -545,12 +561,13 @@ class LinearRegressionSummary private[regression] (
@transient val predictions: DataFrame,
val predictionCol: String,
val labelCol: String,
+ val featuresCol: String,
val model: LinearRegressionModel,
private val diagInvAtWA: Array[Double]) extends Serializable {
@transient private val metrics = new RegressionMetrics(
predictions
- .select(predictionCol, labelCol)
+ .select(col(predictionCol), col(labelCol).cast(DoubleType))
.rdd
.map { case Row(pred: Double, label: Double) => (pred, label) },
!model.getFitIntercept)
@@ -638,6 +655,12 @@ class LinearRegressionSummary private[regression] (
/**
* Standard error of estimated coefficients and intercept.
+ * This value is only available when using the "normal" solver.
+ *
+ * If [[LinearRegression.fitIntercept]] is set to true,
+ * then the last element returned corresponds to the intercept.
+ *
+ * @see [[LinearRegression.solver]]
*/
lazy val coefficientStandardErrors: Array[Double] = {
if (diagInvAtWA.length == 1 && diagInvAtWA(0) == 0) {
@@ -653,12 +676,18 @@ class LinearRegressionSummary private[regression] (
col(model.getWeightCol)).as("wse")).agg(sum(col("wse"))).first().getDouble(0)
}
val sigma2 = rss / degreesOfFreedom
- diagInvAtWA.map(_ * sigma2).map(math.sqrt(_))
+ diagInvAtWA.map(_ * sigma2).map(math.sqrt)
}
}
/**
* T-statistic of estimated coefficients and intercept.
+ * This value is only available when using the "normal" solver.
+ *
+ * If [[LinearRegression.fitIntercept]] is set to true,
+ * then the last element returned corresponds to the intercept.
+ *
+ * @see [[LinearRegression.solver]]
*/
lazy val tValues: Array[Double] = {
if (diagInvAtWA.length == 1 && diagInvAtWA(0) == 0) {
@@ -676,6 +705,12 @@ class LinearRegressionSummary private[regression] (
/**
* Two-sided p-value of estimated coefficients and intercept.
+ * This value is only available when using the "normal" solver.
+ *
+ * If [[LinearRegression.fitIntercept]] is set to true,
+ * then the last element returned corresponds to the intercept.
+ *
+ * @see [[LinearRegression.solver]]
*/
lazy val pValues: Array[Double] = {
if (diagInvAtWA.length == 1 && diagInvAtWA(0) == 0) {
@@ -826,7 +861,7 @@ private class LeastSquaresAggregator(
instance match { case Instance(label, weight, features) =>
require(dim == features.size, s"Dimensions mismatch when adding new sample." +
s" Expecting $dim but got ${features.size}.")
- require(weight >= 0.0, s"instance weight, ${weight} has to be >= 0.0")
+ require(weight >= 0.0, s"instance weight, $weight has to be >= 0.0")
if (weight == 0.0) return this
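
The LinearRegression changes make evaluate() public, thread featuresCol through the summaries, and note that the standard-error / t / p statistics exist only for the "normal" solver. A minimal sketch under those assumptions, with `train` and `test` standing in for DataFrames with "label"/"features" columns:

  import org.apache.spark.ml.regression.LinearRegression

  val lr = new LinearRegression()
    .setSolver("normal")                 // required for the standard-error / t / p statistics
  val model = lr.fit(train)

  // evaluate() is public now and accepts any Dataset[_].
  val testSummary = model.evaluate(test)
  println(s"RMSE = ${testSummary.rootMeanSquaredError}, r2 = ${testSummary.r2}")

  // Normal-solver-only statistics; the last element is the intercept when fitIntercept is true.
  println(model.summary.coefficientStandardErrors.mkString(", "))
  println(model.summary.pValues.mkString(", "))
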
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
index 798947b94a..4c4ff278d4 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
@@ -17,18 +17,22 @@
package org.apache.spark.ml.regression
+import org.json4s.{DefaultFormats, JObject}
+import org.json4s.JsonDSL._
+
import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.ml.{PredictionModel, Predictor}
import org.apache.spark.ml.param.ParamMap
-import org.apache.spark.ml.tree.{DecisionTreeModel, RandomForestParams, TreeEnsembleModel, TreeRegressorParams}
+import org.apache.spark.ml.tree._
import org.apache.spark.ml.tree.impl.RandomForest
-import org.apache.spark.ml.util.{Identifiable, MetadataUtils}
+import org.apache.spark.ml.util._
+import org.apache.spark.ml.util.DefaultParamsReader.Metadata
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
import org.apache.spark.mllib.tree.model.{RandomForestModel => OldRandomForestModel}
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.functions._
@@ -41,7 +45,7 @@ import org.apache.spark.sql.functions._
@Experimental
final class RandomForestRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String)
extends Predictor[Vector, RandomForestRegressor, RandomForestRegressionModel]
- with RandomForestParams with TreeRegressorParams {
+ with RandomForestRegressorParams with DefaultParamsWritable {
@Since("1.4.0")
def this() = this(Identifiable.randomUID("rfr"))
@@ -89,7 +93,7 @@ final class RandomForestRegressor @Since("1.4.0") (@Since("1.4.0") override val
override def setFeatureSubsetStrategy(value: String): this.type =
super.setFeatureSubsetStrategy(value)
- override protected def train(dataset: DataFrame): RandomForestRegressionModel = {
+ override protected def train(dataset: Dataset[_]): RandomForestRegressionModel = {
val categoricalFeatures: Map[Int, Int] =
MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset)
@@ -108,7 +112,7 @@ final class RandomForestRegressor @Since("1.4.0") (@Since("1.4.0") override val
@Since("1.4.0")
@Experimental
-object RandomForestRegressor {
+object RandomForestRegressor extends DefaultParamsReadable[RandomForestRegressor] {
/** Accessor for supported impurity settings: variance */
@Since("1.4.0")
final val supportedImpurities: Array[String] = TreeRegressorParams.supportedImpurities
@@ -117,12 +121,17 @@ object RandomForestRegressor {
@Since("1.4.0")
final val supportedFeatureSubsetStrategies: Array[String] =
RandomForestParams.supportedFeatureSubsetStrategies
+
+ @Since("2.0.0")
+ override def load(path: String): RandomForestRegressor = super.load(path)
+
}
/**
* :: Experimental ::
* [[http://en.wikipedia.org/wiki/Random_forest Random Forest]] model for regression.
* It supports both continuous and categorical features.
+ *
* @param _trees Decision trees in the ensemble.
* @param numFeatures Number of features used by this model
*/
@@ -133,27 +142,29 @@ final class RandomForestRegressionModel private[ml] (
private val _trees: Array[DecisionTreeRegressionModel],
override val numFeatures: Int)
extends PredictionModel[Vector, RandomForestRegressionModel]
- with TreeEnsembleModel with Serializable {
+ with RandomForestRegressionModelParams with TreeEnsembleModel[DecisionTreeRegressionModel]
+ with MLWritable with Serializable {
- require(numTrees > 0, "RandomForestRegressionModel requires at least 1 tree.")
+ require(_trees.nonEmpty, "RandomForestRegressionModel requires at least 1 tree.")
/**
* Construct a random forest regression model, with all trees weighted equally.
+ *
* @param trees Component trees
*/
private[ml] def this(trees: Array[DecisionTreeRegressionModel], numFeatures: Int) =
this(Identifiable.randomUID("rfr"), trees, numFeatures)
@Since("1.4.0")
- override def trees: Array[DecisionTreeModel] = _trees.asInstanceOf[Array[DecisionTreeModel]]
+ override def trees: Array[DecisionTreeRegressionModel] = _trees
// Note: We may add support for weights (based on tree performance) later on.
- private lazy val _treeWeights: Array[Double] = Array.fill[Double](numTrees)(1.0)
+ private lazy val _treeWeights: Array[Double] = Array.fill[Double](_trees.length)(1.0)
@Since("1.4.0")
override def treeWeights: Array[Double] = _treeWeights
- override protected def transformImpl(dataset: DataFrame): DataFrame = {
+ override protected def transformImpl(dataset: Dataset[_]): DataFrame = {
val bcastModel = dataset.sqlContext.sparkContext.broadcast(this)
val predictUDF = udf { (features: Any) =>
bcastModel.value.predict(features.asInstanceOf[Vector])
@@ -165,9 +176,17 @@ final class RandomForestRegressionModel private[ml] (
// TODO: When we add a generic Bagging class, handle transform there. SPARK-7128
// Predict average of tree predictions.
// Ignore the weights since all are 1.0 for now.
- _trees.map(_.rootNode.predictImpl(features).prediction).sum / numTrees
+ _trees.map(_.rootNode.predictImpl(features).prediction).sum / getNumTrees
}
+ /**
+ * Number of trees in ensemble
+ * @deprecated Use [[getNumTrees]] instead. This method will be removed in 2.1.0
+ */
+ // TODO: Once this is removed, then this class can inherit from RandomForestRegressorParams
+ @deprecated("Use getNumTrees instead. This method will be removed in 2.1.0.", "2.0.0")
+ val numTrees: Int = trees.length
+
@Since("1.4.0")
override def copy(extra: ParamMap): RandomForestRegressionModel = {
copyValues(new RandomForestRegressionModel(uid, _trees, numFeatures), extra).setParent(parent)
@@ -175,36 +194,83 @@ final class RandomForestRegressionModel private[ml] (
@Since("1.4.0")
override def toString: String = {
- s"RandomForestRegressionModel (uid=$uid) with $numTrees trees"
+ s"RandomForestRegressionModel (uid=$uid) with $getNumTrees trees"
}
/**
* Estimate of the importance of each feature.
*
- * This generalizes the idea of "Gini" importance to other losses,
- * following the explanation of Gini importance from "Random Forests" documentation
- * by Leo Breiman and Adele Cutler, and following the implementation from scikit-learn.
+ * Each feature's importance is the average of its importance across all trees in the ensemble.
+ * The importance vector is normalized to sum to 1. This method is suggested by Hastie et al.
+ * (Hastie, Tibshirani, Friedman. "The Elements of Statistical Learning, 2nd Edition." 2001.)
+ * and follows the implementation from scikit-learn.
*
- * This feature importance is calculated as follows:
- * - Average over trees:
- * - importance(feature j) = sum (over nodes which split on feature j) of the gain,
- * where gain is scaled by the number of instances passing through node
- * - Normalize importances for tree to sum to 1.
- * - Normalize feature importance vector to sum to 1.
+ * @see [[DecisionTreeRegressionModel.featureImportances]]
*/
@Since("1.5.0")
- lazy val featureImportances: Vector = RandomForest.featureImportances(trees, numFeatures)
+ lazy val featureImportances: Vector = TreeEnsembleModel.featureImportances(trees, numFeatures)
/** (private[ml]) Convert to a model in the old API */
private[ml] def toOld: OldRandomForestModel = {
new OldRandomForestModel(OldAlgo.Regression, _trees.map(_.toOld))
}
+
+ @Since("2.0.0")
+ override def write: MLWriter =
+ new RandomForestRegressionModel.RandomForestRegressionModelWriter(this)
}
-private[ml] object RandomForestRegressionModel {
+@Since("2.0.0")
+object RandomForestRegressionModel extends MLReadable[RandomForestRegressionModel] {
+
+ @Since("2.0.0")
+ override def read: MLReader[RandomForestRegressionModel] = new RandomForestRegressionModelReader
+
+ @Since("2.0.0")
+ override def load(path: String): RandomForestRegressionModel = super.load(path)
+
+ private[RandomForestRegressionModel]
+ class RandomForestRegressionModelWriter(instance: RandomForestRegressionModel)
+ extends MLWriter {
+
+ override protected def saveImpl(path: String): Unit = {
+ val extraMetadata: JObject = Map(
+ "numFeatures" -> instance.numFeatures,
+ "numTrees" -> instance.getNumTrees)
+ EnsembleModelReadWrite.saveImpl(instance, path, sqlContext, extraMetadata)
+ }
+ }
+
+ private class RandomForestRegressionModelReader extends MLReader[RandomForestRegressionModel] {
+
+ /** Checked against metadata when loading model */
+ private val className = classOf[RandomForestRegressionModel].getName
+ private val treeClassName = classOf[DecisionTreeRegressionModel].getName
+
+ override def load(path: String): RandomForestRegressionModel = {
+ implicit val format = DefaultFormats
+ val (metadata: Metadata, treesData: Array[(Metadata, Node)], treeWeights: Array[Double]) =
+ EnsembleModelReadWrite.loadImpl(path, sqlContext, className, treeClassName)
+ val numFeatures = (metadata.metadata \ "numFeatures").extract[Int]
+ val numTrees = (metadata.metadata \ "numTrees").extract[Int]
+
+ val trees: Array[DecisionTreeRegressionModel] = treesData.map { case (treeMetadata, root) =>
+ val tree =
+ new DecisionTreeRegressionModel(treeMetadata.uid, root, numFeatures)
+ DefaultParamsReader.getAndSetParams(tree, treeMetadata)
+ tree
+ }
+ require(numTrees == trees.length, s"RandomForestRegressionModel.load expected $numTrees" +
+ s" trees based on metadata but found ${trees.length} trees.")
+
+ val model = new RandomForestRegressionModel(metadata.uid, trees, numFeatures)
+ DefaultParamsReader.getAndSetParams(model, metadata)
+ model
+ }
+ }
- /** (private[ml]) Convert a model from the old API */
- def fromOld(
+ /** Convert a model from the old API */
+ private[ml] def fromOld(
oldModel: OldRandomForestModel,
parent: RandomForestRegressor,
categoricalFeatures: Map[Int, Int],
@@ -215,6 +281,7 @@ private[ml] object RandomForestRegressionModel {
// parent for each tree is null since there is no good way to set this.
DecisionTreeRegressionModel.fromOld(tree, null, categoricalFeatures)
}
- new RandomForestRegressionModel(parent.uid, newTrees, numFeatures)
+ val uid = if (parent != null) parent.uid else Identifiable.randomUID("rfr")
+ new RandomForestRegressionModel(uid, newTrees, numFeatures)
}
}
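
The random forest changes mirror the GBT ones: getNumTrees supersedes the deprecated numTrees field, featureImportances is delegated to TreeEnsembleModel, and the model gains a writer and reader. A minimal spark-shell sketch, with `training` and the save path as placeholders:

  import org.apache.spark.ml.regression.{RandomForestRegressionModel, RandomForestRegressor}

  // `training` is an assumed DataFrame with "label"/"features" columns.
  val rf = new RandomForestRegressor()
    .setNumTrees(20)
    .setFeatureSubsetStrategy("onethird")
  val model = rf.fit(training)

  println(model.getNumTrees)             // replaces the now-deprecated numTrees field
  println(model.featureImportances)      // per-tree importances averaged, normalized to sum to 1

  model.write.overwrite().save("/tmp/rfr-model")
  val restored = RandomForestRegressionModel.load("/tmp/rfr-model")
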
diff --git a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala
index 13a13f0a7e..2f1f2523fd 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala
@@ -19,23 +19,25 @@ package org.apache.spark.ml.source.libsvm
import java.io.IOException
+import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileStatus, Path}
import org.apache.hadoop.io.{NullWritable, Text}
import org.apache.hadoop.mapreduce.{Job, RecordWriter, TaskAttemptContext}
import org.apache.hadoop.mapreduce.lib.output.TextOutputFormat
import org.apache.spark.annotation.Since
-import org.apache.spark.broadcast.Broadcast
-import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
+import org.apache.spark.mllib.linalg.{Vector, Vectors, VectorUDT}
+import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.util.MLUtils
-import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, DataFrameReader, Row, SQLContext}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.encoders.RowEncoder
+import org.apache.spark.sql.catalyst.expressions.{AttributeReference, JoinedRow}
+import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
+import org.apache.spark.sql.execution.datasources.{CaseInsensitiveMap, HadoopFileLinesReader, PartitionedFile}
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types._
import org.apache.spark.util.SerializableConfiguration
-import org.apache.spark.util.collection.BitSet
private[libsvm] class LibSVMOutputWriter(
path: String,
@@ -110,13 +112,16 @@ class DefaultSource extends FileFormat with DataSourceRegister {
@Since("1.6.0")
override def shortName(): String = "libsvm"
+ override def toString: String = "LibSVM"
+
private def verifySchema(dataSchema: StructType): Unit = {
if (dataSchema.size != 2 ||
(!dataSchema(0).dataType.sameType(DataTypes.DoubleType)
|| !dataSchema(1).dataType.sameType(new VectorUDT()))) {
- throw new IOException(s"Illegal schema for libsvm data, schema=${dataSchema}")
+ throw new IOException(s"Illegal schema for libsvm data, schema=$dataSchema")
}
}
+
override def inferSchema(
sqlContext: SQLContext,
options: Map[String, String],
@@ -127,6 +132,32 @@ class DefaultSource extends FileFormat with DataSourceRegister {
StructField("features", new VectorUDT(), nullable = false) :: Nil))
}
+ override def prepareRead(
+ sqlContext: SQLContext,
+ options: Map[String, String],
+ files: Seq[FileStatus]): Map[String, String] = {
+ def computeNumFeatures(): Int = {
+ val dataFiles = files.filterNot(_.getPath.getName startsWith "_")
+ val path = if (dataFiles.length == 1) {
+ dataFiles.head.getPath.toUri.toString
+ } else if (dataFiles.isEmpty) {
+ throw new IOException("No input path specified for libsvm data")
+ } else {
+ throw new IOException("Multiple input paths are not supported for libsvm data.")
+ }
+
+ val sc = sqlContext.sparkContext
+ val parsed = MLUtils.parseLibSVMFile(sc, path, sc.defaultParallelism)
+ MLUtils.computeNumFeatures(parsed)
+ }
+
+ val numFeatures = options.get("numFeatures").filter(_.toInt > 0).getOrElse {
+ computeNumFeatures()
+ }
+
+ new CaseInsensitiveMap(options + ("numFeatures" -> numFeatures.toString))
+ }
+
override def prepareWrite(
sqlContext: SQLContext,
job: Job,
@@ -144,36 +175,51 @@ class DefaultSource extends FileFormat with DataSourceRegister {
}
}
- override def buildInternalScan(
+ override def buildReader(
sqlContext: SQLContext,
dataSchema: StructType,
- requiredColumns: Array[String],
- filters: Array[Filter],
- bucketSet: Option[BitSet],
- inputFiles: Seq[FileStatus],
- broadcastedConf: Broadcast[SerializableConfiguration],
- options: Map[String, String]): RDD[InternalRow] = {
- // TODO: This does not handle cases where column pruning has been performed.
-
+ partitionSchema: StructType,
+ requiredSchema: StructType,
+ filters: Seq[Filter],
+ options: Map[String, String]): (PartitionedFile) => Iterator[InternalRow] = {
verifySchema(dataSchema)
- val dataFiles = inputFiles.filterNot(_.getPath.getName startsWith "_")
-
- val path = if (dataFiles.length == 1) dataFiles(0).getPath.toUri.toString
- else if (dataFiles.isEmpty) throw new IOException("No input path specified for libsvm data")
- else throw new IOException("Multiple input paths are not supported for libsvm data.")
-
- val numFeatures = options.getOrElse("numFeatures", "-1").toInt
- val vectorType = options.getOrElse("vectorType", "sparse")
-
- val sc = sqlContext.sparkContext
- val baseRdd = MLUtils.loadLibSVMFile(sc, path, numFeatures)
- val sparse = vectorType == "sparse"
- baseRdd.map { pt =>
- val features = if (sparse) pt.features.toSparse else pt.features.toDense
- Row(pt.label, features)
- }.mapPartitions { externalRows =>
- val converter = RowEncoder(dataSchema)
- externalRows.map(converter.toRow)
+ val numFeatures = options("numFeatures").toInt
+ assert(numFeatures > 0)
+
+ val sparse = options.getOrElse("vectorType", "sparse") == "sparse"
+
+ val broadcastedConf = sqlContext.sparkContext.broadcast(
+ new SerializableConfiguration(new Configuration(sqlContext.sparkContext.hadoopConfiguration))
+ )
+
+ (file: PartitionedFile) => {
+ val points =
+ new HadoopFileLinesReader(file, broadcastedConf.value.value)
+ .map(_.toString.trim)
+ .filterNot(line => line.isEmpty || line.startsWith("#"))
+ .map { line =>
+ val (label, indices, values) = MLUtils.parseLibSVMRecord(line)
+ LabeledPoint(label, Vectors.sparse(numFeatures, indices, values))
+ }
+
+ val converter = RowEncoder(requiredSchema)
+
+ val unsafeRowIterator = points.map { pt =>
+ val features = if (sparse) pt.features.toSparse else pt.features.toDense
+ converter.toRow(Row(pt.label, features))
+ }
+
+ def toAttribute(f: StructField): AttributeReference =
+ AttributeReference(f.name, f.dataType, f.nullable, f.metadata)()
+
+ // Appends partition values
+ val fullOutput = (requiredSchema ++ partitionSchema).map(toAttribute)
+ val joinedRow = new JoinedRow()
+ val appendPartitionColumns = GenerateUnsafeProjection.generate(fullOutput, fullOutput)
+
+ unsafeRowIterator.map { dataRow =>
+ appendPartitionColumns(joinedRow(dataRow, file.partitionValues))
+ }
}
}
}
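
From the reader side, the libsvm source is still driven by the same two options; "numFeatures" may now be omitted, in which case prepareRead scans the single input file to infer it. A spark-shell sketch against the sample file shipped with the repo:

  val df = sqlContext.read
    .format("libsvm")
    .option("numFeatures", "780")
    .option("vectorType", "dense")       // default is "sparse"
    .load("data/mllib/sample_libsvm_data.txt")

  df.printSchema()                       // label: double, features: vector
  df.show(5)
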
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/BaggedPoint.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/BaggedPoint.scala
index 572815df0b..4e372702f0 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/BaggedPoint.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/BaggedPoint.scala
@@ -15,7 +15,7 @@
* limitations under the License.
*/
-package org.apache.spark.mllib.tree.impl
+package org.apache.spark.ml.tree.impl
import org.apache.commons.math3.distribution.PoissonDistribution
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DTStatsAggregator.scala
index c745e9f8db..61091bb803 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DTStatsAggregator.scala
@@ -15,7 +15,7 @@
* limitations under the License.
*/
-package org.apache.spark.mllib.tree.impl
+package org.apache.spark.ml.tree.impl
import org.apache.spark.mllib.tree.impurity._
@@ -86,6 +86,7 @@ private[spark] class DTStatsAggregator(
/**
* Get an [[ImpurityCalculator]] for a given (node, feature, bin).
+ *
* @param featureOffset This is a pre-computed (node, feature) offset
* from [[getFeatureOffset]].
*/
@@ -118,6 +119,7 @@ private[spark] class DTStatsAggregator(
/**
* Faster version of [[update]].
* Update the stats for a given (feature, bin), using the given label.
+ *
* @param featureOffset This is a pre-computed feature offset
* from [[getFeatureOffset]].
*/
@@ -138,6 +140,7 @@ private[spark] class DTStatsAggregator(
/**
* For a given feature, merge the stats for two bins.
+ *
* @param featureOffset This is a pre-computed feature offset
* from [[getFeatureOffset]].
* @param binIndex The other bin is merged into this bin.
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeMetadata.scala
index 4f27dc44ef..c7cde1563f 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeMetadata.scala
@@ -15,7 +15,7 @@
* limitations under the License.
*/
-package org.apache.spark.mllib.tree.impl
+package org.apache.spark.ml.tree.impl
import scala.collection.mutable
@@ -183,11 +183,16 @@ private[spark] object DecisionTreeMetadata extends Logging {
}
case _ => featureSubsetStrategy
}
+
+ val isIntRegex = "^([1-9]\\d*)$".r
+ val isFractionRegex = "^(0?\\.\\d*[1-9]\\d*|1\\.0+)$".r
val numFeaturesPerNode: Int = _featureSubsetStrategy match {
case "all" => numFeatures
case "sqrt" => math.sqrt(numFeatures).ceil.toInt
case "log2" => math.max(1, (math.log(numFeatures) / math.log(2)).ceil.toInt)
case "onethird" => (numFeatures / 3.0).ceil.toInt
+ case isIntRegex(number) => if (BigInt(number) > numFeatures) numFeatures else number.toInt
+ case isFractionRegex(fraction) => (fraction.toDouble * numFeatures).ceil.toInt
}
new DecisionTreeMetadata(numFeatures, numExamples, numClasses, numBins.max,
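
The two regexes added here extend featureSubsetStrategy beyond the named options: a whole number selects that many features per node (capped at numFeatures) and a fraction in (0, 1] selects ceil(fraction * numFeatures). A standalone toy reimplementation of the match above, for illustration only:

  object FeatureSubsetDemo extends App {
    val numFeatures = 100
    val isIntRegex = "^([1-9]\\d*)$".r
    val isFractionRegex = "^(0?\\.\\d*[1-9]\\d*|1\\.0+)$".r

    def numFeaturesPerNode(strategy: String): Int = strategy match {
      case "all" => numFeatures
      case "sqrt" => math.sqrt(numFeatures).ceil.toInt
      case "log2" => math.max(1, (math.log(numFeatures) / math.log(2)).ceil.toInt)
      case "onethird" => (numFeatures / 3.0).ceil.toInt
      case isIntRegex(number) => if (BigInt(number) > numFeatures) numFeatures else number.toInt
      case isFractionRegex(fraction) => (fraction.toDouble * numFeatures).ceil.toInt
    }

    Seq("sqrt", "30", "500", "0.25").foreach { s =>
      println(s"$s -> ${numFeaturesPerNode(s)}")   // 10, 30, 100, 25
    }
  }
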
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala
index 1c8a9b4dfe..b6334762c7 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala
@@ -20,16 +20,17 @@ package org.apache.spark.ml.tree.impl
import org.apache.spark.internal.Logging
import org.apache.spark.ml.regression.{DecisionTreeRegressionModel, DecisionTreeRegressor}
import org.apache.spark.mllib.impl.PeriodicRDDCheckpointer
+import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
import org.apache.spark.mllib.tree.configuration.{BoostingStrategy => OldBoostingStrategy}
-import org.apache.spark.mllib.tree.impl.TimeTracker
import org.apache.spark.mllib.tree.impurity.{Variance => OldVariance}
import org.apache.spark.mllib.tree.loss.{Loss => OldLoss}
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel
-private[ml] object GradientBoostedTrees extends Logging {
+
+private[spark] object GradientBoostedTrees extends Logging {
/**
* Method to train a gradient boosting model
@@ -106,7 +107,7 @@ private[ml] object GradientBoostedTrees extends Logging {
initTree: DecisionTreeRegressionModel,
loss: OldLoss): RDD[(Double, Double)] = {
data.map { lp =>
- val pred = initTreeWeight * initTree.rootNode.predictImpl(lp.features).prediction
+ val pred = updatePrediction(lp.features, 0.0, initTree, initTreeWeight)
val error = loss.computeError(pred, lp.label)
(pred, error)
}
@@ -132,7 +133,7 @@ private[ml] object GradientBoostedTrees extends Logging {
val newPredError = data.zip(predictionAndError).mapPartitions { iter =>
iter.map { case (lp, (pred, error)) =>
- val newPred = pred + tree.rootNode.predictImpl(lp.features).prediction * treeWeight
+ val newPred = updatePrediction(lp.features, pred, tree, treeWeight)
val newError = loss.computeError(newPred, lp.label)
(newPred, newError)
}
@@ -141,6 +142,97 @@ private[ml] object GradientBoostedTrees extends Logging {
}
/**
+ * Add prediction from a new boosting iteration to an existing prediction.
+ *
+ * @param features Vector of features representing a single data point.
+ * @param prediction The existing prediction.
+ * @param tree New Decision Tree model.
+ * @param weight Tree weight.
+ * @return Updated prediction.
+ */
+ def updatePrediction(
+ features: Vector,
+ prediction: Double,
+ tree: DecisionTreeRegressionModel,
+ weight: Double): Double = {
+ prediction + tree.rootNode.predictImpl(features).prediction * weight
+ }
+
+ /**
+ * Method to calculate error of the base learner for the gradient boosting calculation.
+ * Note: This method is not used by the gradient boosting algorithm but is useful for debugging
+ * purposes.
+ * @param data Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
+ * @param trees Boosted Decision Tree models
+ * @param treeWeights Learning rates at each boosting iteration.
+ * @param loss evaluation metric.
+ * @return Measure of model error on data
+ */
+ def computeError(
+ data: RDD[LabeledPoint],
+ trees: Array[DecisionTreeRegressionModel],
+ treeWeights: Array[Double],
+ loss: OldLoss): Double = {
+ data.map { lp =>
+ val predicted = trees.zip(treeWeights).foldLeft(0.0) { case (acc, (model, weight)) =>
+ updatePrediction(lp.features, acc, model, weight)
+ }
+ loss.computeError(predicted, lp.label)
+ }.mean()
+ }
+
+ /**
+ * Method to compute error or loss for every iteration of gradient boosting.
+ *
+ * @param data RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]
+ * @param trees Boosted Decision Tree models
+ * @param treeWeights Learning rates at each boosting iteration.
+ * @param loss evaluation metric.
+ * @param algo algorithm for the ensemble, either Classification or Regression
+ * @return an array with index i having the losses or errors for the ensemble
+ * containing the first i+1 trees
+ */
+ def evaluateEachIteration(
+ data: RDD[LabeledPoint],
+ trees: Array[DecisionTreeRegressionModel],
+ treeWeights: Array[Double],
+ loss: OldLoss,
+ algo: OldAlgo.Value): Array[Double] = {
+
+ val sc = data.sparkContext
+ val remappedData = algo match {
+ case OldAlgo.Classification => data.map(x => new LabeledPoint((x.label * 2) - 1, x.features))
+ case _ => data
+ }
+
+ val numIterations = trees.length
+ val evaluationArray = Array.fill(numIterations)(0.0)
+ val localTreeWeights = treeWeights
+
+ var predictionAndError = computeInitialPredictionAndError(
+ remappedData, localTreeWeights(0), trees(0), loss)
+
+ evaluationArray(0) = predictionAndError.values.mean()
+
+ val broadcastTrees = sc.broadcast(trees)
+ (1 until numIterations).foreach { nTree =>
+ predictionAndError = remappedData.zip(predictionAndError).mapPartitions { iter =>
+ val currentTree = broadcastTrees.value(nTree)
+ val currentTreeWeight = localTreeWeights(nTree)
+ iter.map { case (point, (pred, error)) =>
+ val newPred = updatePrediction(point.features, pred, currentTree, currentTreeWeight)
+ val newError = loss.computeError(newPred, point.label)
+ (newPred, newError)
+ }
+ }
+ evaluationArray(nTree) = predictionAndError.values.mean()
+ }
+
+ broadcastTrees.unpersist()
+ evaluationArray
+ }
+
+ /**
* Internal method for performing regression using trees as base learners.
* @param input training dataset
* @param validationInput validation dataset, ignored if validate is set to false.
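
The prediction bookkeeping in this file reduces to one idea: each boosting iteration adds a weighted tree prediction to the running prediction, and the per-iteration error is the mean loss of that running prediction. A toy, Spark-free sketch of that shape, with plain functions standing in for the regression trees (assumption):

  object BoostedErrorDemo extends App {
    val trees: Seq[Double => Double] = Seq(x => x, x => 0.5 * x, x => -0.1 * x)
    val treeWeights: Seq[Double] = Seq(1.0, 0.1, 0.1)
    val data: Seq[(Double, Double)] = Seq((1.0, 1.2), (2.0, 2.1), (3.0, 2.7))  // (feature, label)

    def squaredError(pred: Double, label: Double): Double = (pred - label) * (pred - label)

    // Same shape as evaluateEachIteration: mean error of the first i trees, for i = 1..n.
    val errorPerIteration = (1 to trees.length).map { i =>
      data.map { case (x, label) =>
        val pred = trees.take(i).zip(treeWeights).foldLeft(0.0) {
          case (acc, (tree, w)) => acc + tree(x) * w   // same update as updatePrediction
        }
        squaredError(pred, label)
      }.sum / data.length
    }
    println(errorPerIteration.mkString(", "))
  }
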
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/NodeIdCache.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/NodeIdCache.scala
index 2c8286766f..9d697a36b6 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/NodeIdCache.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/NodeIdCache.scala
@@ -26,7 +26,6 @@ import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.internal.Logging
import org.apache.spark.ml.tree.{LearningNode, Split}
-import org.apache.spark.mllib.tree.impl.BaggedPoint
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
index 7774ae64e5..7b1fd089f2 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
@@ -26,16 +26,12 @@ import org.apache.spark.internal.Logging
import org.apache.spark.ml.classification.DecisionTreeClassificationModel
import org.apache.spark.ml.regression.DecisionTreeRegressionModel
import org.apache.spark.ml.tree._
-import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy}
-import org.apache.spark.mllib.tree.impl.{BaggedPoint, DecisionTreeMetadata, DTStatsAggregator,
- TimeTracker}
import org.apache.spark.mllib.tree.impurity.ImpurityCalculator
import org.apache.spark.mllib.tree.model.ImpurityStats
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel
-import org.apache.spark.util.collection.OpenHashMap
import org.apache.spark.util.random.{SamplingUtils, XORShiftRandom}
@@ -332,7 +328,7 @@ private[spark] object RandomForest extends Logging {
/**
* Given a group of nodes, this finds the best split for each node.
*
- * @param input Training data: RDD of [[org.apache.spark.mllib.tree.impl.TreePoint]]
+ * @param input Training data: RDD of [[org.apache.spark.ml.tree.impl.TreePoint]]
* @param metadata Learning and dataset metadata
* @param topNodes Root node for each tree. Used for matching instances with nodes.
* @param nodesForGroup Mapping: treeIndex --> nodes to be split in tree
@@ -1105,112 +1101,4 @@ private[spark] object RandomForest extends Logging {
}
}
- /**
- * Given a Random Forest model, compute the importance of each feature.
- * This generalizes the idea of "Gini" importance to other losses,
- * following the explanation of Gini importance from "Random Forests" documentation
- * by Leo Breiman and Adele Cutler, and following the implementation from scikit-learn.
- *
- * This feature importance is calculated as follows:
- * - Average over trees:
- * - importance(feature j) = sum (over nodes which split on feature j) of the gain,
- * where gain is scaled by the number of instances passing through node
- * - Normalize importances for tree to sum to 1.
- * - Normalize feature importance vector to sum to 1.
- *
- * @param trees Unweighted forest of trees
- * @param numFeatures Number of features in model (even if not all are explicitly used by
- * the model).
- * If -1, then numFeatures is set based on the max feature index in all trees.
- * @return Feature importance values, of length numFeatures.
- */
- private[ml] def featureImportances(trees: Array[DecisionTreeModel], numFeatures: Int): Vector = {
- val totalImportances = new OpenHashMap[Int, Double]()
- trees.foreach { tree =>
- // Aggregate feature importance vector for this tree
- val importances = new OpenHashMap[Int, Double]()
- computeFeatureImportance(tree.rootNode, importances)
- // Normalize importance vector for this tree, and add it to total.
- // TODO: In the future, also support normalizing by tree.rootNode.impurityStats.count?
- val treeNorm = importances.map(_._2).sum
- if (treeNorm != 0) {
- importances.foreach { case (idx, impt) =>
- val normImpt = impt / treeNorm
- totalImportances.changeValue(idx, normImpt, _ + normImpt)
- }
- }
- }
- // Normalize importances
- normalizeMapValues(totalImportances)
- // Construct vector
- val d = if (numFeatures != -1) {
- numFeatures
- } else {
- // Find max feature index used in trees
- val maxFeatureIndex = trees.map(_.maxSplitFeatureIndex()).max
- maxFeatureIndex + 1
- }
- if (d == 0) {
- assert(totalImportances.size == 0, s"Unknown error in computing feature" +
- s" importance: No splits found, but some non-zero importances.")
- }
- val (indices, values) = totalImportances.iterator.toSeq.sortBy(_._1).unzip
- Vectors.sparse(d, indices.toArray, values.toArray)
- }
-
- /**
- * Given a Decision Tree model, compute the importance of each feature.
- * This generalizes the idea of "Gini" importance to other losses,
- * following the explanation of Gini importance from "Random Forests" documentation
- * by Leo Breiman and Adele Cutler, and following the implementation from scikit-learn.
- *
- * This feature importance is calculated as follows:
- * - importance(feature j) = sum (over nodes which split on feature j) of the gain,
- * where gain is scaled by the number of instances passing through node
- * - Normalize importances for tree to sum to 1.
- *
- * @param tree Decision tree to compute importances for.
- * @param numFeatures Number of features in model (even if not all are explicitly used by
- * the model).
- * If -1, then numFeatures is set based on the max feature index in all trees.
- * @return Feature importance values, of length numFeatures.
- */
- private[ml] def featureImportances(tree: DecisionTreeModel, numFeatures: Int): Vector = {
- featureImportances(Array(tree), numFeatures)
- }
-
- /**
- * Recursive method for computing feature importances for one tree.
- * This walks down the tree, adding to the importance of 1 feature at each node.
- * @param node Current node in recursion
- * @param importances Aggregate feature importances, modified by this method
- */
- private[impl] def computeFeatureImportance(
- node: Node,
- importances: OpenHashMap[Int, Double]): Unit = {
- node match {
- case n: InternalNode =>
- val feature = n.split.featureIndex
- val scaledGain = n.gain * n.impurityStats.count
- importances.changeValue(feature, scaledGain, _ + scaledGain)
- computeFeatureImportance(n.leftChild, importances)
- computeFeatureImportance(n.rightChild, importances)
- case n: LeafNode =>
- // do nothing
- }
- }
-
- /**
- * Normalize the values of this map to sum to 1, in place.
- * If all values are 0, this method does nothing.
- * @param map Map with non-negative values.
- */
- private[impl] def normalizeMapValues(map: OpenHashMap[Int, Double]): Unit = {
- val total = map.map(_._2).sum
- if (total != 0) {
- val keys = map.iterator.map(_._1).toArray
- keys.foreach { key => map.changeValue(key, 0.0, _ / total) }
- }
- }
-
}
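
The importance computation removed here reappears below as TreeEnsembleModel.featureImportances. A hand-worked toy run of the arithmetic it describes (invented gains, not real model output): normalize each tree's scaled gains to sum to 1, sum across trees, then normalize the total vector.

  object ImportanceDemo extends App {
    // Per-tree "gain scaled by instance count", keyed by feature index (toy numbers).
    val perTreeScaledGains = Seq(Map(0 -> 3.0, 1 -> 1.0), Map(1 -> 2.0))

    // Normalize within each tree so its importances sum to 1.
    val perTreeNormalized = perTreeScaledGains.map { gains =>
      val norm = gains.values.sum
      gains.mapValues(_ / norm)          // tree 1: (0 -> 0.75, 1 -> 0.25); tree 2: (1 -> 1.0)
    }

    // Sum across trees, then normalize the whole vector to sum to 1.
    val totals = perTreeNormalized.flatten.groupBy(_._1).mapValues(_.map(_._2).sum)
    val grandTotal = totals.values.sum
    val importances = totals.mapValues(_ / grandTotal)
    println(importances)                 // feature 0 -> 0.375, feature 1 -> 0.625
  }
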
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TimeTracker.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/TimeTracker.scala
index 70afaa162b..4cc250aa46 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TimeTracker.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/TimeTracker.scala
@@ -15,7 +15,7 @@
* limitations under the License.
*/
-package org.apache.spark.mllib.tree.impl
+package org.apache.spark.ml.tree.impl
import scala.collection.mutable.{HashMap => MutableHashMap}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/TreePoint.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/TreePoint.scala
index 9fa27e5e1f..3a2bf3c725 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/TreePoint.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/TreePoint.scala
@@ -19,7 +19,6 @@ package org.apache.spark.ml.tree.impl
import org.apache.spark.ml.tree.{ContinuousSplit, Split}
import org.apache.spark.mllib.regression.LabeledPoint
-import org.apache.spark.mllib.tree.impl.DecisionTreeMetadata
import org.apache.spark.rdd.RDD
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
index ef40c9068f..f38e1ec7c0 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
@@ -17,16 +17,22 @@
package org.apache.spark.ml.tree
+import scala.reflect.ClassTag
+
import org.apache.hadoop.fs.Path
import org.json4s._
import org.json4s.jackson.JsonMethods._
-import org.apache.spark.ml.param.Param
-import org.apache.spark.ml.util.DefaultParamsReader
+import org.apache.spark.ml.param.{Param, Params}
+import org.apache.spark.ml.tree.DecisionTreeModelReadWrite.NodeData
+import org.apache.spark.ml.util.{DefaultParamsReader, DefaultParamsWriter}
+import org.apache.spark.ml.util.DefaultParamsReader.Metadata
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.tree.impurity.ImpurityCalculator
import org.apache.spark.mllib.tree.model.{DecisionTreeModel => OldDecisionTreeModel}
-import org.apache.spark.sql.SQLContext
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.{Dataset, SQLContext}
+import org.apache.spark.util.collection.OpenHashMap
/**
* Abstraction for Decision Tree models.
@@ -70,7 +76,7 @@ private[spark] trait DecisionTreeModel {
*/
private[ml] def maxSplitFeatureIndex(): Int = rootNode.maxSplitFeatureIndex()
- /** Convert to spark.mllib DecisionTreeModel (losing some infomation) */
+ /** Convert to spark.mllib DecisionTreeModel (losing some information) */
private[spark] def toOld: OldDecisionTreeModel
}
@@ -78,14 +84,21 @@ private[spark] trait DecisionTreeModel {
* Abstraction for models which are ensembles of decision trees
*
* TODO: Add support for predicting probabilities and raw predictions SPARK-3727
+ *
+ * @tparam M Type of tree model in this ensemble
*/
-private[ml] trait TreeEnsembleModel {
+private[ml] trait TreeEnsembleModel[M <: DecisionTreeModel] {
// Note: We use getTrees since subclasses of TreeEnsembleModel will store subclasses of
// DecisionTreeModel.
/** Trees in this ensemble. Warning: These have null parent Estimators. */
- def trees: Array[DecisionTreeModel]
+ def trees: Array[M]
+
+ /**
+ * Number of trees in ensemble
+ */
+ val getNumTrees: Int = trees.length
/** Weights for each tree, zippable with [[trees]] */
def treeWeights: Array[Double]
@@ -97,7 +110,7 @@ private[ml] trait TreeEnsembleModel {
/** Summary of the model */
override def toString: String = {
// Implementing classes should generally override this method to be more descriptive.
- s"TreeEnsembleModel with $numTrees trees"
+ s"TreeEnsembleModel with ${trees.length} trees"
}
/** Full description of model */
@@ -108,13 +121,129 @@ private[ml] trait TreeEnsembleModel {
}.fold("")(_ + _)
}
- /** Number of trees in ensemble */
- val numTrees: Int = trees.length
-
/** Total number of nodes, summed over all trees in the ensemble. */
lazy val totalNumNodes: Int = trees.map(_.numNodes).sum
}
+private[ml] object TreeEnsembleModel {
+
+ /**
+ * Given a tree ensemble model, compute the importance of each feature.
+ * This generalizes the idea of "Gini" importance to other losses,
+ * following the explanation of Gini importance from "Random Forests" documentation
+ * by Leo Breiman and Adele Cutler, and following the implementation from scikit-learn.
+ *
+ * For collections of trees, including boosting and bagging, Hastie et al.
+ * propose to use the average of single tree importances across all trees in the ensemble.
+ *
+ * This feature importance is calculated as follows:
+ * - Average over trees:
+ * - importance(feature j) = sum (over nodes which split on feature j) of the gain,
+ * where gain is scaled by the number of instances passing through node
+ * - Normalize importances for tree to sum to 1.
+ * - Normalize feature importance vector to sum to 1.
+ *
+ * References:
+ * - Hastie, Tibshirani, Friedman. "The Elements of Statistical Learning, 2nd Edition." 2001.
+ *
+ * @param trees Unweighted collection of trees
+ * @param numFeatures Number of features in model (even if not all are explicitly used by
+ * the model).
+ * If -1, then numFeatures is set based on the max feature index in all trees.
+ * @return Feature importance values, of length numFeatures.
+ */
+ def featureImportances[M <: DecisionTreeModel](trees: Array[M], numFeatures: Int): Vector = {
+ val totalImportances = new OpenHashMap[Int, Double]()
+ trees.foreach { tree =>
+ // Aggregate feature importance vector for this tree
+ val importances = new OpenHashMap[Int, Double]()
+ computeFeatureImportance(tree.rootNode, importances)
+ // Normalize importance vector for this tree, and add it to total.
+ // TODO: In the future, also support normalizing by tree.rootNode.impurityStats.count?
+ val treeNorm = importances.map(_._2).sum
+ if (treeNorm != 0) {
+ importances.foreach { case (idx, impt) =>
+ val normImpt = impt / treeNorm
+ totalImportances.changeValue(idx, normImpt, _ + normImpt)
+ }
+ }
+ }
+ // Normalize importances
+ normalizeMapValues(totalImportances)
+ // Construct vector
+ val d = if (numFeatures != -1) {
+ numFeatures
+ } else {
+ // Find max feature index used in trees
+ val maxFeatureIndex = trees.map(_.maxSplitFeatureIndex()).max
+ maxFeatureIndex + 1
+ }
+ if (d == 0) {
+ assert(totalImportances.size == 0, s"Unknown error in computing feature" +
+ s" importance: No splits found, but some non-zero importances.")
+ }
+ val (indices, values) = totalImportances.iterator.toSeq.sortBy(_._1).unzip
+ Vectors.sparse(d, indices.toArray, values.toArray)
+ }
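Illustrative sketch (not part of this patch): the same normalize-per-tree-then-average scheme on plain Scala collections, with made-up gains, to make the arithmetic described above concrete.

    import scala.collection.mutable

    // Hypothetical scaled gains (gain * instance count) per feature index, one map per tree.
    val perTreeGains = Seq(Map(0 -> 3.0, 2 -> 1.0), Map(1 -> 2.0, 2 -> 2.0))

    val total = mutable.Map.empty[Int, Double].withDefaultValue(0.0)
    perTreeGains.foreach { gains =>
      val treeNorm = gains.values.sum                  // normalize within each tree
      if (treeNorm != 0) {
        gains.foreach { case (idx, g) => total(idx) += g / treeNorm }
      }
    }
    val grand = total.values.sum                       // then normalize across the ensemble
    val importances = total.map { case (idx, v) => (idx, v / grand) }.toMap
    // importances == Map(0 -> 0.375, 1 -> 0.25, 2 -> 0.375), summing to 1.0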
+
+ /**
+ * Given a Decision Tree model, compute the importance of each feature.
+ * This generalizes the idea of "Gini" importance to other losses,
+ * following the explanation of Gini importance from "Random Forests" documentation
+ * by Leo Breiman and Adele Cutler, and following the implementation from scikit-learn.
+ *
+ * This feature importance is calculated as follows:
+ * - importance(feature j) = sum (over nodes which split on feature j) of the gain,
+ * where gain is scaled by the number of instances passing through node
+ * - Normalize importances for tree to sum to 1.
+ *
+ * @param tree Decision tree to compute importances for.
+ * @param numFeatures Number of features in model (even if not all are explicitly used by
+ * the model).
+ * If -1, then numFeatures is set based on the max feature index in all trees.
+ * @return Feature importance values, of length numFeatures.
+ */
+ def featureImportances[M <: DecisionTreeModel : ClassTag](tree: M, numFeatures: Int): Vector = {
+ featureImportances(Array(tree), numFeatures)
+ }
+
+ /**
+ * Recursive method for computing feature importances for one tree.
+ * This walks down the tree, adding to the importance of 1 feature at each node.
+ *
+ * @param node Current node in recursion
+ * @param importances Aggregate feature importances, modified by this method
+ */
+ def computeFeatureImportance(
+ node: Node,
+ importances: OpenHashMap[Int, Double]): Unit = {
+ node match {
+ case n: InternalNode =>
+ val feature = n.split.featureIndex
+ val scaledGain = n.gain * n.impurityStats.count
+ importances.changeValue(feature, scaledGain, _ + scaledGain)
+ computeFeatureImportance(n.leftChild, importances)
+ computeFeatureImportance(n.rightChild, importances)
+ case n: LeafNode =>
+ // do nothing
+ }
+ }
+
+ /**
+ * Normalize the values of this map to sum to 1, in place.
+ * If all values are 0, this method does nothing.
+ *
+ * @param map Map with non-negative values.
+ */
+ def normalizeMapValues(map: OpenHashMap[Int, Double]): Unit = {
+ val total = map.map(_._2).sum
+ if (total != 0) {
+ val keys = map.iterator.map(_._1).toArray
+ keys.foreach { key => map.changeValue(key, 0.0, _ / total) }
+ }
+ }
+}
+
/** Helper classes for tree model persistence */
private[ml] object DecisionTreeModelReadWrite {
@@ -196,6 +325,10 @@ private[ml] object DecisionTreeModelReadWrite {
}
}
+ /**
+ * Load a decision tree from a file.
+ * @return Root node of reconstructed tree
+ */
def loadTreeNodes(
path: String,
metadata: DefaultParamsReader.Metadata,
@@ -211,9 +344,18 @@ private[ml] object DecisionTreeModelReadWrite {
val dataPath = new Path(path, "data").toString
val data = sqlContext.read.parquet(dataPath).as[NodeData]
+ buildTreeFromNodes(data.collect(), impurityType)
+ }
+ /**
+ * Given all data for all nodes in a tree, rebuild the tree.
+ * @param data Unsorted node data
+ * @param impurityType Impurity type for this tree
+ * @return Root node of reconstructed tree
+ */
+ def buildTreeFromNodes(data: Array[NodeData], impurityType: String): Node = {
// Load all nodes, sorted by ID.
- val nodes: Array[NodeData] = data.collect().sortBy(_.id)
+ val nodes = data.sortBy(_.id)
// Sanity checks; could remove
assert(nodes.head.id == 0, s"Decision Tree load failed. Expected smallest node ID to be 0," +
s" but found ${nodes.head.id}")
@@ -238,3 +380,105 @@ private[ml] object DecisionTreeModelReadWrite {
finalNodes.head
}
}
+
+private[ml] object EnsembleModelReadWrite {
+
+ /**
+ * Helper method for saving a tree ensemble to disk.
+ *
+ * @param instance Tree ensemble model
+ * @param path Path to which to save the ensemble model.
+ * @param extraMetadata Metadata such as numFeatures, numClasses, numTrees.
+ */
+ def saveImpl[M <: Params with TreeEnsembleModel[_ <: DecisionTreeModel]](
+ instance: M,
+ path: String,
+ sql: SQLContext,
+ extraMetadata: JObject): Unit = {
+ DefaultParamsWriter.saveMetadata(instance, path, sql.sparkContext, Some(extraMetadata))
+ val treesMetadataWeights: Array[(Int, String, Double)] = instance.trees.zipWithIndex.map {
+ case (tree, treeID) =>
+ (treeID,
+ DefaultParamsWriter.getMetadataToSave(tree.asInstanceOf[Params], sql.sparkContext),
+ instance.treeWeights(treeID))
+ }
+ val treesMetadataPath = new Path(path, "treesMetadata").toString
+ sql.createDataFrame(treesMetadataWeights).toDF("treeID", "metadata", "weights")
+ .write.parquet(treesMetadataPath)
+ val dataPath = new Path(path, "data").toString
+ val nodeDataRDD = sql.sparkContext.parallelize(instance.trees.zipWithIndex).flatMap {
+ case (tree, treeID) => EnsembleNodeData.build(tree, treeID)
+ }
+ sql.createDataFrame(nodeDataRDD).write.parquet(dataPath)
+ }
+
+ /**
+ * Helper method for loading a tree ensemble from disk.
+ * This reconstructs all trees, returning the root nodes.
+ * @param path Path given to [[saveImpl()]]
+ * @param className Class name for ensemble model type
+ * @param treeClassName Class name for tree model type in the ensemble
+ * @return (ensemble metadata, array over trees of (tree metadata, root node)),
+ * where the root node is linked with all descendants
+ * @see [[saveImpl()]] for how the model was saved
+ */
+ def loadImpl(
+ path: String,
+ sql: SQLContext,
+ className: String,
+ treeClassName: String): (Metadata, Array[(Metadata, Node)], Array[Double]) = {
+ import sql.implicits._
+ implicit val format = DefaultFormats
+ val metadata = DefaultParamsReader.loadMetadata(path, sql.sparkContext, className)
+
+ // Get impurity to construct ImpurityCalculator for each node
+ val impurityType: String = {
+ val impurityJson: JValue = metadata.getParamValue("impurity")
+ Param.jsonDecode[String](compact(render(impurityJson)))
+ }
+
+ val treesMetadataPath = new Path(path, "treesMetadata").toString
+ val treesMetadataRDD: RDD[(Int, (Metadata, Double))] = sql.read.parquet(treesMetadataPath)
+ .select("treeID", "metadata", "weights").as[(Int, String, Double)].rdd.map {
+ case (treeID: Int, json: String, weights: Double) =>
+ treeID -> (DefaultParamsReader.parseMetadata(json, treeClassName), weights)
+ }
+
+ val treesMetadataWeights = treesMetadataRDD.sortByKey().values.collect()
+ val treesMetadata = treesMetadataWeights.map(_._1)
+ val treesWeights = treesMetadataWeights.map(_._2)
+
+ val dataPath = new Path(path, "data").toString
+ val nodeData: Dataset[EnsembleNodeData] =
+ sql.read.parquet(dataPath).as[EnsembleNodeData]
+ val rootNodesRDD: RDD[(Int, Node)] =
+ nodeData.rdd.map(d => (d.treeID, d.nodeData)).groupByKey().map {
+ case (treeID: Int, nodeData: Iterable[NodeData]) =>
+ treeID -> DecisionTreeModelReadWrite.buildTreeFromNodes(nodeData.toArray, impurityType)
+ }
+ val rootNodes: Array[Node] = rootNodesRDD.sortByKey().values.collect()
+ (metadata, treesMetadata.zip(rootNodes), treesWeights)
+ }
+
+ /**
+ * Info for one [[Node]] in a tree ensemble
+ *
+ * @param treeID Tree index
+ * @param nodeData Data for this node
+ */
+ case class EnsembleNodeData(
+ treeID: Int,
+ nodeData: NodeData)
+
+ object EnsembleNodeData {
+ /**
+ * Create [[EnsembleNodeData]] instances for the given tree.
+ *
+ * @return Sequence of nodes for this tree
+ */
+ def build(tree: DecisionTreeModel, treeID: Int): Seq[EnsembleNodeData] = {
+ val (nodeData: Seq[NodeData], _) = NodeData.build(tree.rootNode, 0)
+ nodeData.map(nd => EnsembleNodeData(treeID, nd))
+ }
+ }
+}
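For orientation, an illustrative view (not part of this patch) of the on-disk layout saveImpl produces and loadImpl reads back, where <path> is the hypothetical save directory:

    <path>/metadata          ensemble-level JSON written by DefaultParamsWriter.saveMetadata (plus extraMetadata)
    <path>/treesMetadata     Parquet with columns treeID, metadata (per-tree JSON string), weights
    <path>/data              Parquet of EnsembleNodeData rows, i.e. (treeID, nodeData)

loadImpl groups the data rows by treeID and rebuilds each tree's root with DecisionTreeModelReadWrite.buildTreeFromNodes.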
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
index 4fbd957677..b6783911ad 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
@@ -23,7 +23,7 @@ import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.util.SchemaUtils
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, BoostingStrategy => OldBoostingStrategy, Strategy => OldStrategy}
import org.apache.spark.mllib.tree.impurity.{Entropy => OldEntropy, Gini => OldGini, Impurity => OldImpurity, Variance => OldVariance}
-import org.apache.spark.mllib.tree.loss.{Loss => OldLoss}
+import org.apache.spark.mllib.tree.loss.{AbsoluteError => OldAbsoluteError, LogLoss => OldLogLoss, Loss => OldLoss, SquaredError => OldSquaredError}
import org.apache.spark.sql.types.{DataType, DoubleType, StructType}
/**
@@ -315,22 +315,8 @@ private[ml] trait TreeEnsembleParams extends DecisionTreeParams {
}
}
-/**
- * Parameters for Random Forest algorithms.
- *
- * Note: Marked as private and DeveloperApi since this may be made public in the future.
- */
-private[ml] trait RandomForestParams extends TreeEnsembleParams {
-
- /**
- * Number of trees to train (>= 1).
- * If 1, then no bootstrapping is used. If > 1, then bootstrapping is done.
- * TODO: Change to always do bootstrapping (simpler). SPARK-7130
- * (default = 20)
- * @group param
- */
- final val numTrees: IntParam = new IntParam(this, "numTrees", "Number of trees to train (>= 1)",
- ParamValidators.gtEq(1))
+/** Used for [[RandomForestParams]] */
+private[ml] trait HasFeatureSubsetStrategy extends Params {
/**
* The number of features to consider for splits at each tree node.
@@ -343,6 +329,8 @@ private[ml] trait RandomForestParams extends TreeEnsembleParams {
* - "onethird": use 1/3 of the features
* - "sqrt": use sqrt(number of features)
* - "log2": use log2(number of features)
+ * - "n": when n is in the range (0, 1.0], use n * number of features. When n
+ * is in the range (1, number of features), use n features.
* (default = "auto")
*
* These various settings are based on the following references:
@@ -360,29 +348,71 @@ private[ml] trait RandomForestParams extends TreeEnsembleParams {
"The number of features to consider for splits at each tree node." +
s" Supported options: ${RandomForestParams.supportedFeatureSubsetStrategies.mkString(", ")}",
(value: String) =>
- RandomForestParams.supportedFeatureSubsetStrategies.contains(value.toLowerCase))
+ RandomForestParams.supportedFeatureSubsetStrategies.contains(value.toLowerCase)
+ || value.matches(RandomForestParams.supportedFeatureSubsetStrategiesRegex))
- setDefault(numTrees -> 20, featureSubsetStrategy -> "auto")
+ setDefault(featureSubsetStrategy -> "auto")
/** @group setParam */
- def setNumTrees(value: Int): this.type = set(numTrees, value)
+ def setFeatureSubsetStrategy(value: String): this.type = set(featureSubsetStrategy, value)
/** @group getParam */
- final def getNumTrees: Int = $(numTrees)
+ final def getFeatureSubsetStrategy: String = $(featureSubsetStrategy).toLowerCase
+}
+
+/**
+ * Used for [[RandomForestParams]].
+ * This is separated out from [[RandomForestParams]] because of an issue with the
+ * `numTrees` method conflicting with this Param in the Estimator.
+ */
+private[ml] trait HasNumTrees extends Params {
+
+ /**
+ * Number of trees to train (>= 1).
+ * If 1, then no bootstrapping is used. If > 1, then bootstrapping is done.
+ * TODO: Change to always do bootstrapping (simpler). SPARK-7130
+ * (default = 20)
+ * @group param
+ */
+ final val numTrees: IntParam = new IntParam(this, "numTrees", "Number of trees to train (>= 1)",
+ ParamValidators.gtEq(1))
+
+ setDefault(numTrees -> 20)
/** @group setParam */
- def setFeatureSubsetStrategy(value: String): this.type = set(featureSubsetStrategy, value)
+ def setNumTrees(value: Int): this.type = set(numTrees, value)
/** @group getParam */
- final def getFeatureSubsetStrategy: String = $(featureSubsetStrategy).toLowerCase
+ final def getNumTrees: Int = $(numTrees)
}
+/**
+ * Parameters for Random Forest algorithms.
+ */
+private[ml] trait RandomForestParams extends TreeEnsembleParams
+ with HasFeatureSubsetStrategy with HasNumTrees
+
private[spark] object RandomForestParams {
// These options should be lowercase.
final val supportedFeatureSubsetStrategies: Array[String] =
Array("auto", "all", "onethird", "sqrt", "log2").map(_.toLowerCase)
+
+ // The regex to capture a fraction in (0.0, 1.0], or an integer "n" with 0 < n <= (number of features)
+ final val supportedFeatureSubsetStrategiesRegex = "^(?:[1-9]\\d*|0?\\.\\d*[1-9]\\d*|1\\.0+)$"
}
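A quick standalone check (illustrative, not part of this patch) of what the new regex accepts as a featureSubsetStrategy value:

    // Same pattern as supportedFeatureSubsetStrategiesRegex above.
    val re = "^(?:[1-9]\\d*|0?\\.\\d*[1-9]\\d*|1\\.0+)$"
    Seq("3", "20", "0.5", ".5", "1.0", "0.0", "0", "log2").foreach { s =>
      println(s"$s matches: ${s.matches(re)}")
    }
    // "3", "20", "0.5", ".5" and "1.0" match; "0.0" and "0" do not, and named strategies
    // such as "log2" are accepted through supportedFeatureSubsetStrategies instead.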
+private[ml] trait RandomForestClassifierParams
+ extends RandomForestParams with TreeClassifierParams
+
+private[ml] trait RandomForestClassificationModelParams extends TreeEnsembleParams
+ with HasFeatureSubsetStrategy with TreeClassifierParams
+
+private[ml] trait RandomForestRegressorParams
+ extends RandomForestParams with TreeRegressorParams
+
+private[ml] trait RandomForestRegressionModelParams extends TreeEnsembleParams
+ with HasFeatureSubsetStrategy with TreeRegressorParams
+
/**
* Parameters for Gradient-Boosted Tree algorithms.
*
@@ -432,3 +462,74 @@ private[ml] trait GBTParams extends TreeEnsembleParams with HasMaxIter with HasS
/** Get old Gradient Boosting Loss type */
private[ml] def getOldLossType: OldLoss
}
+
+private[ml] object GBTClassifierParams {
+ // The losses below should be lowercase.
+ /** Accessor for supported loss settings: logistic */
+ final val supportedLossTypes: Array[String] = Array("logistic").map(_.toLowerCase)
+}
+
+private[ml] trait GBTClassifierParams extends GBTParams with TreeClassifierParams {
+
+ /**
+ * Loss function which GBT tries to minimize. (case-insensitive)
+ * Supported: "logistic"
+ * (default = logistic)
+ * @group param
+ */
+ val lossType: Param[String] = new Param[String](this, "lossType", "Loss function which GBT" +
+ " tries to minimize (case-insensitive). Supported options:" +
+ s" ${GBTClassifierParams.supportedLossTypes.mkString(", ")}",
+ (value: String) => GBTClassifierParams.supportedLossTypes.contains(value.toLowerCase))
+
+ setDefault(lossType -> "logistic")
+
+ /** @group getParam */
+ def getLossType: String = $(lossType).toLowerCase
+
+ /** (private[ml]) Convert new loss to old loss. */
+ override private[ml] def getOldLossType: OldLoss = {
+ getLossType match {
+ case "logistic" => OldLogLoss
+ case _ =>
+ // Should never happen because of check in setter method.
+ throw new RuntimeException(s"GBTClassifier was given bad loss type: $getLossType")
+ }
+ }
+}
+
+private[ml] object GBTRegressorParams {
+ // The losses below should be lowercase.
+ /** Accessor for supported loss settings: squared (L2), absolute (L1) */
+ final val supportedLossTypes: Array[String] = Array("squared", "absolute").map(_.toLowerCase)
+}
+
+private[ml] trait GBTRegressorParams extends GBTParams with TreeRegressorParams {
+
+ /**
+ * Loss function which GBT tries to minimize. (case-insensitive)
+ * Supported: "squared" (L2) and "absolute" (L1)
+ * (default = squared)
+ * @group param
+ */
+ val lossType: Param[String] = new Param[String](this, "lossType", "Loss function which GBT" +
+ " tries to minimize (case-insensitive). Supported options:" +
+ s" ${GBTRegressorParams.supportedLossTypes.mkString(", ")}",
+ (value: String) => GBTRegressorParams.supportedLossTypes.contains(value.toLowerCase))
+
+ setDefault(lossType -> "squared")
+
+ /** @group getParam */
+ def getLossType: String = $(lossType).toLowerCase
+
+ /** (private[ml]) Convert new loss to old loss. */
+ override private[ml] def getOldLossType: OldLoss = {
+ getLossType match {
+ case "squared" => OldSquaredError
+ case "absolute" => OldAbsoluteError
+ case _ =>
+ // Should never happen because of check in setter method.
+ throw new RuntimeException(s"GBTRegressorParams was given bad loss type: $getLossType")
+ }
+ }
+}
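A minimal usage sketch (not part of this patch, assuming the public GBTClassifier and GBTRegressor estimators that mix these traits in):

    import org.apache.spark.ml.classification.GBTClassifier
    import org.apache.spark.ml.regression.GBTRegressor

    val gbtc = new GBTClassifier().setLossType("logistic")   // the only supported classification loss
    val gbtr = new GBTRegressor().setLossType("absolute")    // "absolute" maps to OldAbsoluteError (L1), "squared" to OldSquaredError (L2)
    println(gbtr.getLossType)                                // "absolute" (returned lowercase)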
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
index 963f81cb3e..de563d4fad 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
@@ -17,27 +17,25 @@
package org.apache.spark.ml.tuning
+import java.util.{List => JList}
+
+import scala.collection.JavaConverters._
+
import com.github.fommil.netlib.F2jBLAS
import org.apache.hadoop.fs.Path
-import org.json4s.{DefaultFormats, JObject}
-import org.json4s.jackson.JsonMethods._
+import org.json4s.DefaultFormats
-import org.apache.spark.SparkContext
import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.internal.Logging
import org.apache.spark.ml._
-import org.apache.spark.ml.classification.OneVsRestParams
import org.apache.spark.ml.evaluation.Evaluator
-import org.apache.spark.ml.feature.RFormulaModel
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared.HasSeed
import org.apache.spark.ml.util._
-import org.apache.spark.ml.util.DefaultParamsReader.Metadata
import org.apache.spark.mllib.util.MLUtils
-import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.types.StructType
-
/**
* Params for [[CrossValidator]] and [[CrossValidatorModel]].
*/
@@ -45,6 +43,7 @@ private[ml] trait CrossValidatorParams extends ValidatorParams with HasSeed {
/**
* Param for number of folds for cross validation. Must be >= 2.
* Default: 3
+ *
* @group param
*/
val numFolds: IntParam = new IntParam(this, "numFolds",
@@ -91,8 +90,8 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String)
@Since("2.0.0")
def setSeed(value: Long): this.type = set(seed, value)
- @Since("1.4.0")
- override def fit(dataset: DataFrame): CrossValidatorModel = {
+ @Since("2.0.0")
+ override def fit(dataset: Dataset[_]): CrossValidatorModel = {
val schema = dataset.schema
transformSchema(schema, logging = true)
val sqlCtx = dataset.sqlContext
@@ -101,7 +100,7 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String)
val epm = $(estimatorParamMaps)
val numModels = epm.length
val metrics = new Array[Double](epm.length)
- val splits = MLUtils.kFold(dataset.rdd, $(numFolds), $(seed))
+ val splits = MLUtils.kFold(dataset.toDF.rdd, $(numFolds), $(seed))
splits.zipWithIndex.foreach { case ((training, validation), splitIndex) =>
val trainingDataset = sqlCtx.createDataFrame(training, schema).cache()
val validationDataset = sqlCtx.createDataFrame(validation, schema).cache()
@@ -163,10 +162,10 @@ object CrossValidator extends MLReadable[CrossValidator] {
private[CrossValidator] class CrossValidatorWriter(instance: CrossValidator) extends MLWriter {
- SharedReadWrite.validateParams(instance)
+ ValidatorParams.validateParams(instance)
override protected def saveImpl(path: String): Unit =
- SharedReadWrite.saveImpl(path, instance, sc)
+ ValidatorParams.saveImpl(path, instance, sc)
}
private class CrossValidatorReader extends MLReader[CrossValidator] {
@@ -175,8 +174,11 @@ object CrossValidator extends MLReadable[CrossValidator] {
private val className = classOf[CrossValidator].getName
override def load(path: String): CrossValidator = {
- val (metadata, estimator, evaluator, estimatorParamMaps, numFolds) =
- SharedReadWrite.load(path, sc, className)
+ implicit val format = DefaultFormats
+
+ val (metadata, estimator, evaluator, estimatorParamMaps) =
+ ValidatorParams.loadImpl(path, sc, className)
+ val numFolds = (metadata.params \ "numFolds").extract[Int]
new CrossValidator(metadata.uid)
.setEstimator(estimator)
.setEvaluator(evaluator)
@@ -184,123 +186,6 @@ object CrossValidator extends MLReadable[CrossValidator] {
.setNumFolds(numFolds)
}
}
-
- private object CrossValidatorReader {
- /**
- * Examine the given estimator (which may be a compound estimator) and extract a mapping
- * from UIDs to corresponding [[Params]] instances.
- */
- def getUidMap(instance: Params): Map[String, Params] = {
- val uidList = getUidMapImpl(instance)
- val uidMap = uidList.toMap
- if (uidList.size != uidMap.size) {
- throw new RuntimeException("CrossValidator.load found a compound estimator with stages" +
- s" with duplicate UIDs. List of UIDs: ${uidList.map(_._1).mkString(", ")}")
- }
- uidMap
- }
-
- def getUidMapImpl(instance: Params): List[(String, Params)] = {
- val subStages: Array[Params] = instance match {
- case p: Pipeline => p.getStages.asInstanceOf[Array[Params]]
- case pm: PipelineModel => pm.stages.asInstanceOf[Array[Params]]
- case v: ValidatorParams => Array(v.getEstimator, v.getEvaluator)
- case ovr: OneVsRestParams =>
- // TODO: SPARK-11892: This case may require special handling.
- throw new UnsupportedOperationException("CrossValidator write will fail because it" +
- " cannot yet handle an estimator containing type: ${ovr.getClass.getName}")
- case rformModel: RFormulaModel => Array(rformModel.pipelineModel)
- case _: Params => Array()
- }
- val subStageMaps = subStages.map(getUidMapImpl).foldLeft(List.empty[(String, Params)])(_ ++ _)
- List((instance.uid, instance)) ++ subStageMaps
- }
- }
-
- private[tuning] object SharedReadWrite {
-
- /**
- * Check that [[CrossValidator.evaluator]] and [[CrossValidator.estimator]] are Writable.
- * This does not check [[CrossValidator.estimatorParamMaps]].
- */
- def validateParams(instance: ValidatorParams): Unit = {
- def checkElement(elem: Params, name: String): Unit = elem match {
- case stage: MLWritable => // good
- case other =>
- throw new UnsupportedOperationException("CrossValidator write will fail " +
- s" because it contains $name which does not implement Writable." +
- s" Non-Writable $name: ${other.uid} of type ${other.getClass}")
- }
- checkElement(instance.getEvaluator, "evaluator")
- checkElement(instance.getEstimator, "estimator")
- // Check to make sure all Params apply to this estimator. Throw an error if any do not.
- // Extraneous Params would cause problems when loading the estimatorParamMaps.
- val uidToInstance: Map[String, Params] = CrossValidatorReader.getUidMap(instance)
- instance.getEstimatorParamMaps.foreach { case pMap: ParamMap =>
- pMap.toSeq.foreach { case ParamPair(p, v) =>
- require(uidToInstance.contains(p.parent), s"CrossValidator save requires all Params in" +
- s" estimatorParamMaps to apply to this CrossValidator, its Estimator, or its" +
- s" Evaluator. An extraneous Param was found: $p")
- }
- }
- }
-
- private[tuning] def saveImpl(
- path: String,
- instance: CrossValidatorParams,
- sc: SparkContext,
- extraMetadata: Option[JObject] = None): Unit = {
- import org.json4s.JsonDSL._
-
- val estimatorParamMapsJson = compact(render(
- instance.getEstimatorParamMaps.map { case paramMap =>
- paramMap.toSeq.map { case ParamPair(p, v) =>
- Map("parent" -> p.parent, "name" -> p.name, "value" -> p.jsonEncode(v))
- }
- }.toSeq
- ))
- val jsonParams = List(
- "numFolds" -> parse(instance.numFolds.jsonEncode(instance.getNumFolds)),
- "estimatorParamMaps" -> parse(estimatorParamMapsJson)
- )
- DefaultParamsWriter.saveMetadata(instance, path, sc, extraMetadata, Some(jsonParams))
-
- val evaluatorPath = new Path(path, "evaluator").toString
- instance.getEvaluator.asInstanceOf[MLWritable].save(evaluatorPath)
- val estimatorPath = new Path(path, "estimator").toString
- instance.getEstimator.asInstanceOf[MLWritable].save(estimatorPath)
- }
-
- private[tuning] def load[M <: Model[M]](
- path: String,
- sc: SparkContext,
- expectedClassName: String): (Metadata, Estimator[M], Evaluator, Array[ParamMap], Int) = {
-
- val metadata = DefaultParamsReader.loadMetadata(path, sc, expectedClassName)
-
- implicit val format = DefaultFormats
- val evaluatorPath = new Path(path, "evaluator").toString
- val evaluator = DefaultParamsReader.loadParamsInstance[Evaluator](evaluatorPath, sc)
- val estimatorPath = new Path(path, "estimator").toString
- val estimator = DefaultParamsReader.loadParamsInstance[Estimator[M]](estimatorPath, sc)
-
- val uidToParams = Map(evaluator.uid -> evaluator) ++ CrossValidatorReader.getUidMap(estimator)
-
- val numFolds = (metadata.params \ "numFolds").extract[Int]
- val estimatorParamMaps: Array[ParamMap] =
- (metadata.params \ "estimatorParamMaps").extract[Seq[Seq[Map[String, String]]]].map {
- pMap =>
- val paramPairs = pMap.map { case pInfo: Map[String, String] =>
- val est = uidToParams(pInfo("parent"))
- val param = est.getParam(pInfo("name"))
- val value = param.jsonDecode(pInfo("value"))
- param -> value
- }
- ParamMap(paramPairs: _*)
- }.toArray
- (metadata, estimator, evaluator, estimatorParamMaps, numFolds)
- }
- }
}
/**
@@ -319,8 +204,13 @@ class CrossValidatorModel private[ml] (
@Since("1.5.0") val avgMetrics: Array[Double])
extends Model[CrossValidatorModel] with CrossValidatorParams with MLWritable {
- @Since("1.4.0")
- override def transform(dataset: DataFrame): DataFrame = {
+ /** A Python-friendly auxiliary constructor. */
+ private[ml] def this(uid: String, bestModel: Model[_], avgMetrics: JList[Double]) = {
+ this(uid, bestModel, avgMetrics.asScala.toArray)
+ }
+
+ @Since("2.0.0")
+ override def transform(dataset: Dataset[_]): DataFrame = {
transformSchema(dataset.schema, logging = true)
bestModel.transform(dataset)
}
@@ -346,8 +236,6 @@ class CrossValidatorModel private[ml] (
@Since("1.6.0")
object CrossValidatorModel extends MLReadable[CrossValidatorModel] {
- import CrossValidator.SharedReadWrite
-
@Since("1.6.0")
override def read: MLReader[CrossValidatorModel] = new CrossValidatorModelReader
@@ -357,12 +245,12 @@ object CrossValidatorModel extends MLReadable[CrossValidatorModel] {
private[CrossValidatorModel]
class CrossValidatorModelWriter(instance: CrossValidatorModel) extends MLWriter {
- SharedReadWrite.validateParams(instance)
+ ValidatorParams.validateParams(instance)
override protected def saveImpl(path: String): Unit = {
import org.json4s.JsonDSL._
val extraMetadata = "avgMetrics" -> instance.avgMetrics.toSeq
- SharedReadWrite.saveImpl(path, instance, sc, Some(extraMetadata))
+ ValidatorParams.saveImpl(path, instance, sc, Some(extraMetadata))
val bestModelPath = new Path(path, "bestModel").toString
instance.bestModel.asInstanceOf[MLWritable].save(bestModelPath)
}
@@ -376,8 +264,9 @@ object CrossValidatorModel extends MLReadable[CrossValidatorModel] {
override def load(path: String): CrossValidatorModel = {
implicit val format = DefaultFormats
- val (metadata, estimator, evaluator, estimatorParamMaps, numFolds) =
- SharedReadWrite.load(path, sc, className)
+ val (metadata, estimator, evaluator, estimatorParamMaps) =
+ ValidatorParams.loadImpl(path, sc, className)
+ val numFolds = (metadata.params \ "numFolds").extract[Int]
val bestModelPath = new Path(path, "bestModel").toString
val bestModel = DefaultParamsReader.loadParamsInstance[Model[_]](bestModelPath, sc)
val avgMetrics = (metadata.metadata \ "avgMetrics").extract[Seq[Double]].toArray
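A round-trip sketch of the persistence this rewrite enables (illustrative, not part of this patch; the DataFrame and path are hypothetical):

    import org.apache.spark.ml.classification.LogisticRegression
    import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
    import org.apache.spark.ml.tuning.{CrossValidator, ParamGridBuilder}

    val lr = new LogisticRegression()
    val grid = new ParamGridBuilder().addGrid(lr.regParam, Array(0.01, 0.1)).build()
    val cv = new CrossValidator()
      .setEstimator(lr)
      .setEvaluator(new BinaryClassificationEvaluator())
      .setEstimatorParamMaps(grid)
      .setNumFolds(3)

    cv.save("/tmp/cv")                        // numFolds and estimatorParamMaps saved via ValidatorParams.saveImpl
    val cv2 = CrossValidator.load("/tmp/cv")  // reconstructed by the CrossValidatorReader above
    // val cvModel = cv.fit(trainingDF)       // trainingDF: a hypothetical DataFrame with label/features columns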
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
index 70fa5f0234..12d6905510 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
@@ -17,22 +17,32 @@
package org.apache.spark.ml.tuning
+import java.util.{List => JList}
+
+import scala.collection.JavaConverters._
+import scala.language.existentials
+
+import org.apache.hadoop.fs.Path
+import org.json4s.DefaultFormats
+
import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.internal.Logging
import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.evaluation.Evaluator
import org.apache.spark.ml.param.{DoubleParam, ParamMap, ParamValidators}
-import org.apache.spark.ml.util.Identifiable
-import org.apache.spark.sql.DataFrame
+import org.apache.spark.ml.param.shared.HasSeed
+import org.apache.spark.ml.util._
+import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.types.StructType
/**
* Params for [[TrainValidationSplit]] and [[TrainValidationSplitModel]].
*/
-private[ml] trait TrainValidationSplitParams extends ValidatorParams {
+private[ml] trait TrainValidationSplitParams extends ValidatorParams with HasSeed {
/**
* Param for ratio between train and validation data. Must be between 0 and 1.
* Default: 0.75
+ *
* @group param
*/
val trainRatio: DoubleParam = new DoubleParam(this, "trainRatio",
@@ -55,7 +65,7 @@ private[ml] trait TrainValidationSplitParams extends ValidatorParams {
@Experimental
class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: String)
extends Estimator[TrainValidationSplitModel]
- with TrainValidationSplitParams with Logging {
+ with TrainValidationSplitParams with MLWritable with Logging {
@Since("1.5.0")
def this() = this(Identifiable.randomUID("tvs"))
@@ -76,8 +86,12 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St
@Since("1.5.0")
def setTrainRatio(value: Double): this.type = set(trainRatio, value)
- @Since("1.5.0")
- override def fit(dataset: DataFrame): TrainValidationSplitModel = {
+ /** @group setParam */
+ @Since("2.0.0")
+ def setSeed(value: Long): this.type = set(seed, value)
+
+ @Since("2.0.0")
+ override def fit(dataset: Dataset[_]): TrainValidationSplitModel = {
val schema = dataset.schema
transformSchema(schema, logging = true)
val sqlCtx = dataset.sqlContext
@@ -87,10 +101,10 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St
val numModels = epm.length
val metrics = new Array[Double](epm.length)
- val Array(training, validation) =
- dataset.rdd.randomSplit(Array($(trainRatio), 1 - $(trainRatio)))
- val trainingDataset = sqlCtx.createDataFrame(training, schema).cache()
- val validationDataset = sqlCtx.createDataFrame(validation, schema).cache()
+ val Array(trainingDataset, validationDataset) =
+ dataset.randomSplit(Array($(trainRatio), 1 - $(trainRatio)), $(seed))
+ trainingDataset.cache()
+ validationDataset.cache()
// multi-model training
logDebug(s"Train split with multiple sets of parameters.")
@@ -130,6 +144,47 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St
}
copied
}
+
+ @Since("2.0.0")
+ override def write: MLWriter = new TrainValidationSplit.TrainValidationSplitWriter(this)
+}
+
+@Since("2.0.0")
+object TrainValidationSplit extends MLReadable[TrainValidationSplit] {
+
+ @Since("2.0.0")
+ override def read: MLReader[TrainValidationSplit] = new TrainValidationSplitReader
+
+ @Since("2.0.0")
+ override def load(path: String): TrainValidationSplit = super.load(path)
+
+ private[TrainValidationSplit] class TrainValidationSplitWriter(instance: TrainValidationSplit)
+ extends MLWriter {
+
+ ValidatorParams.validateParams(instance)
+
+ override protected def saveImpl(path: String): Unit =
+ ValidatorParams.saveImpl(path, instance, sc)
+ }
+
+ private class TrainValidationSplitReader extends MLReader[TrainValidationSplit] {
+
+ /** Checked against metadata when loading model */
+ private val className = classOf[TrainValidationSplit].getName
+
+ override def load(path: String): TrainValidationSplit = {
+ implicit val format = DefaultFormats
+
+ val (metadata, estimator, evaluator, estimatorParamMaps) =
+ ValidatorParams.loadImpl(path, sc, className)
+ val trainRatio = (metadata.params \ "trainRatio").extract[Double]
+ new TrainValidationSplit(metadata.uid)
+ .setEstimator(estimator)
+ .setEvaluator(evaluator)
+ .setEstimatorParamMaps(estimatorParamMaps)
+ .setTrainRatio(trainRatio)
+ }
+ }
}
/**
@@ -146,10 +201,15 @@ class TrainValidationSplitModel private[ml] (
@Since("1.5.0") override val uid: String,
@Since("1.5.0") val bestModel: Model[_],
@Since("1.5.0") val validationMetrics: Array[Double])
- extends Model[TrainValidationSplitModel] with TrainValidationSplitParams {
+ extends Model[TrainValidationSplitModel] with TrainValidationSplitParams with MLWritable {
- @Since("1.5.0")
- override def transform(dataset: DataFrame): DataFrame = {
+ /** A Python-friendly auxiliary constructor. */
+ private[ml] def this(uid: String, bestModel: Model[_], validationMetrics: JList[Double]) = {
+ this(uid, bestModel, validationMetrics.asScala.toArray)
+ }
+
+ @Since("2.0.0")
+ override def transform(dataset: Dataset[_]): DataFrame = {
transformSchema(dataset.schema, logging = true)
bestModel.transform(dataset)
}
@@ -167,4 +227,53 @@ class TrainValidationSplitModel private[ml] (
validationMetrics.clone())
copyValues(copied, extra)
}
+
+ @Since("2.0.0")
+ override def write: MLWriter = new TrainValidationSplitModel.TrainValidationSplitModelWriter(this)
+}
+
+@Since("2.0.0")
+object TrainValidationSplitModel extends MLReadable[TrainValidationSplitModel] {
+
+ @Since("2.0.0")
+ override def read: MLReader[TrainValidationSplitModel] = new TrainValidationSplitModelReader
+
+ @Since("2.0.0")
+ override def load(path: String): TrainValidationSplitModel = super.load(path)
+
+ private[TrainValidationSplitModel]
+ class TrainValidationSplitModelWriter(instance: TrainValidationSplitModel) extends MLWriter {
+
+ ValidatorParams.validateParams(instance)
+
+ override protected def saveImpl(path: String): Unit = {
+ import org.json4s.JsonDSL._
+ val extraMetadata = "validationMetrics" -> instance.validationMetrics.toSeq
+ ValidatorParams.saveImpl(path, instance, sc, Some(extraMetadata))
+ val bestModelPath = new Path(path, "bestModel").toString
+ instance.bestModel.asInstanceOf[MLWritable].save(bestModelPath)
+ }
+ }
+
+ private class TrainValidationSplitModelReader extends MLReader[TrainValidationSplitModel] {
+
+ /** Checked against metadata when loading model */
+ private val className = classOf[TrainValidationSplitModel].getName
+
+ override def load(path: String): TrainValidationSplitModel = {
+ implicit val format = DefaultFormats
+
+ val (metadata, estimator, evaluator, estimatorParamMaps) =
+ ValidatorParams.loadImpl(path, sc, className)
+ val trainRatio = (metadata.params \ "trainRatio").extract[Double]
+ val bestModelPath = new Path(path, "bestModel").toString
+ val bestModel = DefaultParamsReader.loadParamsInstance[Model[_]](bestModelPath, sc)
+ val validationMetrics = (metadata.metadata \ "validationMetrics").extract[Seq[Double]].toArray
+ val tvs = new TrainValidationSplitModel(metadata.uid, bestModel, validationMetrics)
+ tvs.set(tvs.estimator, estimator)
+ .set(tvs.evaluator, evaluator)
+ .set(tvs.estimatorParamMaps, estimatorParamMaps)
+ .set(tvs.trainRatio, trainRatio)
+ }
+ }
}
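The analogous single-split workflow, now also persistable (illustrative sketch, not part of this patch; the DataFrame and path are hypothetical):

    import org.apache.spark.ml.classification.LogisticRegression
    import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
    import org.apache.spark.ml.tuning.{ParamGridBuilder, TrainValidationSplit}

    val lr = new LogisticRegression()
    val tvs = new TrainValidationSplit()
      .setEstimator(lr)
      .setEvaluator(new BinaryClassificationEvaluator())
      .setEstimatorParamMaps(new ParamGridBuilder().addGrid(lr.regParam, Array(0.01, 0.1)).build())
      .setTrainRatio(0.8)
      .setSeed(42L)                           // seed support added in this patch

    tvs.save("/tmp/tvs")
    val tvs2 = TrainValidationSplit.load("/tmp/tvs")
    // val model = tvs.fit(trainingDF); model.validationMetrics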
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala
index 953456e8f0..7a4e106aeb 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala
@@ -17,9 +17,17 @@
package org.apache.spark.ml.tuning
-import org.apache.spark.ml.Estimator
+import org.apache.hadoop.fs.Path
+import org.json4s.{DefaultFormats, _}
+import org.json4s.jackson.JsonMethods._
+
+import org.apache.spark.SparkContext
+import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.evaluation.Evaluator
-import org.apache.spark.ml.param.{Param, ParamMap, Params}
+import org.apache.spark.ml.param.{Param, ParamMap, ParamPair, Params}
+import org.apache.spark.ml.util.{DefaultParamsReader, DefaultParamsWriter, MetaAlgorithmReadWrite,
+ MLWritable}
+import org.apache.spark.ml.util.DefaultParamsReader.Metadata
import org.apache.spark.sql.types.StructType
/**
@@ -69,3 +77,108 @@ private[ml] trait ValidatorParams extends Params {
est.copy(firstEstimatorParamMap).transformSchema(schema)
}
}
+
+private[ml] object ValidatorParams {
+ /**
+ * Check that [[ValidatorParams.evaluator]] and [[ValidatorParams.estimator]] are Writable.
+ * This does not check [[ValidatorParams.estimatorParamMaps]].
+ */
+ def validateParams(instance: ValidatorParams): Unit = {
+ def checkElement(elem: Params, name: String): Unit = elem match {
+ case stage: MLWritable => // good
+ case other =>
+ throw new UnsupportedOperationException(instance.getClass.getName + " write will fail " +
+ s" because it contains $name which does not implement Writable." +
+ s" Non-Writable $name: ${other.uid} of type ${other.getClass}")
+ }
+ checkElement(instance.getEvaluator, "evaluator")
+ checkElement(instance.getEstimator, "estimator")
+ // Check to make sure all Params apply to this estimator. Throw an error if any do not.
+ // Extraneous Params would cause problems when loading the estimatorParamMaps.
+ val uidToInstance: Map[String, Params] = MetaAlgorithmReadWrite.getUidMap(instance)
+ instance.getEstimatorParamMaps.foreach { case pMap: ParamMap =>
+ pMap.toSeq.foreach { case ParamPair(p, v) =>
+ require(uidToInstance.contains(p.parent), s"ValidatorParams save requires all Params in" +
+ s" estimatorParamMaps to apply to this ValidatorParams, its Estimator, or its" +
+ s" Evaluator. An extraneous Param was found: $p")
+ }
+ }
+ }
+
+ /**
+ * Generic implementation of save for [[ValidatorParams]] types.
+ * This handles all [[ValidatorParams]] fields and saves [[Param]] values, but the implementing
+ * class needs to handle model data.
+ */
+ def saveImpl(
+ path: String,
+ instance: ValidatorParams,
+ sc: SparkContext,
+ extraMetadata: Option[JObject] = None): Unit = {
+ import org.json4s.JsonDSL._
+
+ val estimatorParamMapsJson = compact(render(
+ instance.getEstimatorParamMaps.map { case paramMap =>
+ paramMap.toSeq.map { case ParamPair(p, v) =>
+ Map("parent" -> p.parent, "name" -> p.name, "value" -> p.jsonEncode(v))
+ }
+ }.toSeq
+ ))
+
+ val validatorSpecificParams = instance match {
+ case cv: CrossValidatorParams =>
+ List("numFolds" -> parse(cv.numFolds.jsonEncode(cv.getNumFolds)))
+ case tvs: TrainValidationSplitParams =>
+ List("trainRatio" -> parse(tvs.trainRatio.jsonEncode(tvs.getTrainRatio)))
+ case _ =>
+ // This should not happen.
+ throw new NotImplementedError("ValidatorParams.saveImpl does not handle type: " +
+ instance.getClass.getCanonicalName)
+ }
+
+ val jsonParams = validatorSpecificParams ++ List(
+ "estimatorParamMaps" -> parse(estimatorParamMapsJson))
+
+ DefaultParamsWriter.saveMetadata(instance, path, sc, extraMetadata, Some(jsonParams))
+
+ val evaluatorPath = new Path(path, "evaluator").toString
+ instance.getEvaluator.asInstanceOf[MLWritable].save(evaluatorPath)
+ val estimatorPath = new Path(path, "estimator").toString
+ instance.getEstimator.asInstanceOf[MLWritable].save(estimatorPath)
+ }
+
+ /**
+ * Generic implementation of load for [[ValidatorParams]] types.
+ * This handles all [[ValidatorParams]] fields, but the implementing
+ * class needs to handle model data and special [[Param]] values.
+ */
+ def loadImpl[M <: Model[M]](
+ path: String,
+ sc: SparkContext,
+ expectedClassName: String): (Metadata, Estimator[M], Evaluator, Array[ParamMap]) = {
+
+ val metadata = DefaultParamsReader.loadMetadata(path, sc, expectedClassName)
+
+ implicit val format = DefaultFormats
+ val evaluatorPath = new Path(path, "evaluator").toString
+ val evaluator = DefaultParamsReader.loadParamsInstance[Evaluator](evaluatorPath, sc)
+ val estimatorPath = new Path(path, "estimator").toString
+ val estimator = DefaultParamsReader.loadParamsInstance[Estimator[M]](estimatorPath, sc)
+
+ val uidToParams = Map(evaluator.uid -> evaluator) ++ MetaAlgorithmReadWrite.getUidMap(estimator)
+
+ val estimatorParamMaps: Array[ParamMap] =
+ (metadata.params \ "estimatorParamMaps").extract[Seq[Seq[Map[String, String]]]].map {
+ pMap =>
+ val paramPairs = pMap.map { case pInfo: Map[String, String] =>
+ val est = uidToParams(pInfo("parent"))
+ val param = est.getParam(pInfo("name"))
+ val value = param.jsonDecode(pInfo("value"))
+ param -> value
+ }
+ ParamMap(paramPairs: _*)
+ }.toArray
+
+ (metadata, estimator, evaluator, estimatorParamMaps)
+ }
+}
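For reference, an illustrative shape (hypothetical uid and values) of the "estimatorParamMaps" field that saveImpl above writes, with one inner list per ParamMap and one object per ParamPair:

    [[{"parent": "logreg_a1b2c3d4", "name": "regParam", "value": "0.1"}],
     [{"parent": "logreg_a1b2c3d4", "name": "regParam", "value": "0.01"}]]

loadImpl resolves each "parent" uid through MetaAlgorithmReadWrite.getUidMap and json-decodes "value" with the matching Param.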
diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala b/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala
new file mode 100644
index 0000000000..7e57cefc44
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala
@@ -0,0 +1,117 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.util
+
+import java.util.concurrent.atomic.AtomicLong
+
+import org.json4s._
+import org.json4s.JsonDSL._
+import org.json4s.jackson.JsonMethods._
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.ml.{Estimator, Model}
+import org.apache.spark.ml.param.Param
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.Dataset
+
+/**
+ * A small wrapper that defines a training session for an estimator, and some methods to log
+ * useful information during this session.
+ *
+ * A new instance is expected to be created within fit().
+ *
+ * @param estimator the estimator that is being fit
+ * @param dataset the training dataset
+ * @tparam E the type of the estimator
+ */
+private[ml] class Instrumentation[E <: Estimator[_]] private (
+ estimator: E, dataset: RDD[_]) extends Logging {
+
+ private val id = Instrumentation.counter.incrementAndGet()
+ private val prefix = {
+ val className = estimator.getClass.getSimpleName
+ s"$className-${estimator.uid}-${dataset.hashCode()}-$id: "
+ }
+
+ init()
+
+ private def init(): Unit = {
+ log(s"training: numPartitions=${dataset.partitions.length}" +
+ s" storageLevel=${dataset.getStorageLevel}")
+ }
+
+ /**
+ * Logs a message with a prefix that uniquely identifies the training session.
+ */
+ def log(msg: String): Unit = {
+ logInfo(prefix + msg)
+ }
+
+ /**
+ * Logs the value of the given parameters for the estimator being used in this session.
+ */
+ def logParams(params: Param[_]*): Unit = {
+ val pairs: Seq[(String, JValue)] = for {
+ p <- params
+ value <- estimator.get(p)
+ } yield {
+ val cast = p.asInstanceOf[Param[Any]]
+ p.name -> parse(cast.jsonEncode(value))
+ }
+ log(compact(render(map2jvalue(pairs.toMap))))
+ }
+
+ def logNumFeatures(num: Long): Unit = {
+ log(compact(render("numFeatures" -> num)))
+ }
+
+ def logNumClasses(num: Long): Unit = {
+ log(compact(render("numClasses" -> num)))
+ }
+
+ /**
+ * Logs the successful completion of the training session and the value of the learned model.
+ */
+ def logSuccess(model: Model[_]): Unit = {
+ log(s"training finished")
+ }
+}
+
+/**
+ * Some common methods for logging information about a training session.
+ */
+private[ml] object Instrumentation {
+ private val counter = new AtomicLong(0)
+
+ /**
+ * Creates an instrumentation object for a training session.
+ */
+ def create[E <: Estimator[_]](
+ estimator: E, dataset: Dataset[_]): Instrumentation[E] = {
+ create[E](estimator, dataset.rdd)
+ }
+
+ /**
+ * Creates an instrumentation object for a training session.
+ */
+ def create[E <: Estimator[_]](
+ estimator: E, dataset: RDD[_]): Instrumentation[E] = {
+ new Instrumentation[E](estimator, dataset)
+ }
+
+}
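Intended usage, sketched (not part of this patch) as it would look inside an ml estimator's fit(); the estimator, its params, and train() are hypothetical:

    override def fit(dataset: Dataset[_]): MyModel = {
      val instr = Instrumentation.create(this, dataset)   // one session per fit() call
      instr.logParams(maxIter, regParam)                  // emitted together as one JSON object
      instr.logNumFeatures(numFeatures)
      val model = train(dataset)                          // hypothetical training step
      instr.logSuccess(model)                             // closes out the session in the logs
      model
    }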
diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
index c95e536abd..7dec07ea14 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
@@ -21,13 +21,18 @@ import java.io.IOException
import org.apache.hadoop.fs.Path
import org.json4s._
-import org.json4s.jackson.JsonMethods._
+import org.json4s.{DefaultFormats, JObject}
import org.json4s.JsonDSL._
+import org.json4s.jackson.JsonMethods._
import org.apache.spark.SparkContext
import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.internal.Logging
+import org.apache.spark.ml._
+import org.apache.spark.ml.classification.{OneVsRest, OneVsRestModel}
+import org.apache.spark.ml.feature.RFormulaModel
import org.apache.spark.ml.param.{ParamPair, Params}
+import org.apache.spark.ml.tuning.ValidatorParams
import org.apache.spark.sql.SQLContext
import org.apache.spark.util.Utils
@@ -139,6 +144,7 @@ private[ml] trait DefaultParamsWritable extends MLWritable { self: Params =>
/**
* Abstract class for utility classes that can load ML instances.
+ *
* @tparam T ML instance type
*/
@Experimental
@@ -157,6 +163,7 @@ abstract class MLReader[T] extends BaseReadWrite {
/**
* Trait for objects that provide [[MLReader]].
+ *
* @tparam T ML instance type
*/
@Experimental
@@ -187,6 +194,7 @@ private[ml] trait DefaultParamsReadable[T] extends MLReadable[T] {
* Default [[MLWriter]] implementation for transformers and estimators that contain basic
* (json4s-serializable) params and no data. This will not handle more complex params or types with
* data (e.g., models with coefficients).
+ *
* @param instance object to save
*/
private[ml] class DefaultParamsWriter(instance: Params) extends MLWriter {
@@ -206,6 +214,7 @@ private[ml] object DefaultParamsWriter {
* - uid
* - paramMap
* - (optionally, extra metadata)
+ *
* @param extraMetadata Extra metadata to be saved at same level as uid, paramMap, etc.
* @param paramMap If given, this is saved in the "paramMap" field.
* Otherwise, all [[org.apache.spark.ml.param.Param]]s are encoded using
@@ -217,6 +226,22 @@ private[ml] object DefaultParamsWriter {
sc: SparkContext,
extraMetadata: Option[JObject] = None,
paramMap: Option[JValue] = None): Unit = {
+ val metadataPath = new Path(path, "metadata").toString
+ val metadataJson = getMetadataToSave(instance, sc, extraMetadata, paramMap)
+ sc.parallelize(Seq(metadataJson), 1).saveAsTextFile(metadataPath)
+ }
+
+ /**
+ * Helper for [[saveMetadata()]] which extracts the JSON to save.
+ * This is useful for ensemble models which need to save metadata for many sub-models.
+ *
+ * @see [[saveMetadata()]] for details on what this includes.
+ */
+ def getMetadataToSave(
+ instance: Params,
+ sc: SparkContext,
+ extraMetadata: Option[JObject] = None,
+ paramMap: Option[JValue] = None): String = {
val uid = instance.uid
val cls = instance.getClass.getName
val params = instance.extractParamMap().toSeq.asInstanceOf[Seq[ParamPair[Any]]]
@@ -234,9 +259,8 @@ private[ml] object DefaultParamsWriter {
case None =>
basicMetadata
}
- val metadataPath = new Path(path, "metadata").toString
- val metadataJson = compact(render(metadata))
- sc.parallelize(Seq(metadataJson), 1).saveAsTextFile(metadataPath)
+ val metadataJson: String = compact(render(metadata))
+ metadataJson
}
}
@@ -244,6 +268,7 @@ private[ml] object DefaultParamsWriter {
* Default [[MLReader]] implementation for transformers and estimators that contain basic
* (json4s-serializable) params and no data. This will not handle more complex params or types with
* data (e.g., models with coefficients).
+ *
* @tparam T ML instance type
* TODO: Consider adding check for correct class name.
*/
@@ -263,6 +288,7 @@ private[ml] object DefaultParamsReader {
/**
* All info from metadata file.
+ *
* @param params paramMap, as a [[JValue]]
* @param metadata All metadata, including the other fields
* @param metadataJson Full metadata file String (for debugging)
@@ -299,13 +325,26 @@ private[ml] object DefaultParamsReader {
}
/**
- * Load metadata from file.
+ * Load metadata saved using [[DefaultParamsWriter.saveMetadata()]]
+ *
* @param expectedClassName If non empty, this is checked against the loaded metadata.
* @throws IllegalArgumentException if expectedClassName is specified and does not match metadata
*/
def loadMetadata(path: String, sc: SparkContext, expectedClassName: String = ""): Metadata = {
val metadataPath = new Path(path, "metadata").toString
val metadataStr = sc.textFile(metadataPath, 1).first()
+ parseMetadata(metadataStr, expectedClassName)
+ }
+
+ /**
+ * Parse metadata JSON string produced by [[DefaultParamsWriter.getMetadataToSave()]].
+ * This is a helper function for [[loadMetadata()]].
+ *
+ * @param metadataStr JSON string of metadata
+ * @param expectedClassName If non empty, this is checked against the loaded metadata.
+ * @throws IllegalArgumentException if expectedClassName is specified and does not match metadata
+ */
+ def parseMetadata(metadataStr: String, expectedClassName: String = ""): Metadata = {
val metadata = parse(metadataStr)
implicit val format = DefaultFormats
@@ -352,3 +391,36 @@ private[ml] object DefaultParamsReader {
cls.getMethod("read").invoke(null).asInstanceOf[MLReader[T]].load(path)
}
}
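An illustrative round trip (not part of this patch) through the new string-level helpers, assuming instance: Params and sc: SparkContext are in scope inside org.apache.spark.ml:

    val json: String = DefaultParamsWriter.getMetadataToSave(instance, sc)   // class name, uid, paramMap, plus any extra metadata
    val meta = DefaultParamsReader.parseMetadata(json)                       // same Metadata that loadMetadata returns

saveMetadata and loadMetadata are now thin wrappers that add the text-file I/O around these two calls.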
+
+/**
+ * Default Meta-Algorithm read and write implementation.
+ */
+private[ml] object MetaAlgorithmReadWrite {
+ /**
+ * Examine the given estimator (which may be a compound estimator) and extract a mapping
+ * from UIDs to corresponding [[Params]] instances.
+ */
+ def getUidMap(instance: Params): Map[String, Params] = {
+ val uidList = getUidMapImpl(instance)
+ val uidMap = uidList.toMap
+ if (uidList.size != uidMap.size) {
+ throw new RuntimeException(s"${instance.getClass.getName}.load found a compound estimator" +
+ s" with stages with duplicate UIDs. List of UIDs: ${uidList.map(_._1).mkString(", ")}.")
+ }
+ uidMap
+ }
+
+ private def getUidMapImpl(instance: Params): List[(String, Params)] = {
+ val subStages: Array[Params] = instance match {
+ case p: Pipeline => p.getStages.asInstanceOf[Array[Params]]
+ case pm: PipelineModel => pm.stages.asInstanceOf[Array[Params]]
+ case v: ValidatorParams => Array(v.getEstimator, v.getEvaluator)
+ case ovr: OneVsRest => Array(ovr.getClassifier)
+ case ovrModel: OneVsRestModel => Array(ovrModel.getClassifier) ++ ovrModel.models
+ case rformModel: RFormulaModel => Array(rformModel.pipelineModel)
+ case _: Params => Array()
+ }
+ val subStageMaps = subStages.map(getUidMapImpl).foldLeft(List.empty[(String, Params)])(_ ++ _)
+ List((instance.uid, instance)) ++ subStageMaps
+ }
+}
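A small concrete case (illustrative, not part of this patch; callable from within org.apache.spark.ml since the object is private[ml]):

    import org.apache.spark.ml.{Pipeline, PipelineStage}
    import org.apache.spark.ml.classification.LogisticRegression
    import org.apache.spark.ml.feature.Tokenizer

    val pipeline = new Pipeline().setStages(Array[PipelineStage](new Tokenizer(), new LogisticRegression()))
    val uidMap = MetaAlgorithmReadWrite.getUidMap(pipeline)
    // Three entries: the pipeline's own uid plus one per stage. This is the map that
    // ValidatorParams.loadImpl uses to resolve the "parent" field of each saved ParamPair.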
diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala b/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala
index 76021ad8f4..334410c962 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala
@@ -17,7 +17,7 @@
package org.apache.spark.ml.util
-import org.apache.spark.sql.types.{DataType, StructField, StructType}
+import org.apache.spark.sql.types.{DataType, NumericType, StructField, StructType}
/**
@@ -44,10 +44,10 @@ private[spark] object SchemaUtils {
}
/**
- * Check whether the given schema contains a column of one of the require data types.
- * @param colName column name
- * @param dataTypes required column data types
- */
+ * Check whether the given schema contains a column of one of the required data types.
+ * @param colName column name
+ * @param dataTypes required column data types
+ */
def checkColumnTypes(
schema: StructType,
colName: String,
@@ -61,6 +61,20 @@ private[spark] object SchemaUtils {
}
/**
+ * Check whether the given schema contains a column of the numeric data type.
+ * @param colName column name
+ */
+ def checkNumericType(
+ schema: StructType,
+ colName: String,
+ msg: String = ""): Unit = {
+ val actualDataType = schema(colName).dataType
+ val message = if (msg != null && msg.trim.length > 0) " " + msg else ""
+ require(actualDataType.isInstanceOf[NumericType], s"Column $colName must be of type " +
+ s"NumericType but was actually of type $actualDataType.$message")
+ }
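A minimal sketch (not part of this patch; SchemaUtils is private[spark], so this assumes code inside an org.apache.spark package) of what the new check enforces:

    import org.apache.spark.sql.types.{DoubleType, StringType, StructField, StructType}

    val schema = StructType(Seq(StructField("label", DoubleType), StructField("text", StringType)))
    SchemaUtils.checkNumericType(schema, "label")      // passes: DoubleType is a NumericType
    // SchemaUtils.checkNumericType(schema, "text")    // would throw IllegalArgumentException: StringType is not numeric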
+
+ /**
* Appends a new column to the input schema. This fails if the given output column already exists.
* @param schema input schema
* @param colName new column name. If this column name is an empty string "", this method returns
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/GaussianMixtureModelWrapper.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/GaussianMixtureModelWrapper.scala
index a689b09341..364d5eea08 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/GaussianMixtureModelWrapper.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/GaussianMixtureModelWrapper.scala
@@ -24,15 +24,15 @@ import org.apache.spark.mllib.clustering.GaussianMixtureModel
import org.apache.spark.mllib.linalg.{Vector, Vectors}
/**
- * Wrapper around GaussianMixtureModel to provide helper methods in Python
- */
+ * Wrapper around GaussianMixtureModel to provide helper methods in Python
+ */
private[python] class GaussianMixtureModelWrapper(model: GaussianMixtureModel) {
val weights: Vector = Vectors.dense(model.weights)
val k: Int = weights.size
/**
- * Returns gaussians as a List of Vectors and Matrices corresponding each MultivariateGaussian
- */
+ * Returns gaussians as a List of Vectors and Matrices corresponding to each MultivariateGaussian
+ */
val gaussians: Array[Byte] = {
val modelGaussians = model.gaussians.map { gaussian =>
Array[Any](gaussian.mu, gaussian.sigma)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/Word2VecModelWrapper.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/Word2VecModelWrapper.scala
index 073f03e16f..05273c3434 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/Word2VecModelWrapper.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/Word2VecModelWrapper.scala
@@ -27,8 +27,8 @@ import org.apache.spark.mllib.feature.Word2VecModel
import org.apache.spark.mllib.linalg.{Vector, Vectors}
/**
- * Wrapper around Word2VecModel to provide helper methods in Python
- */
+ * Wrapper around Word2VecModel to provide helper methods in Python
+ */
private[python] class Word2VecModelWrapper(model: Word2VecModel) {
def transform(word: String): Vector = {
model.transform(word)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala
index c0404be019..f10570e662 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala
@@ -418,7 +418,7 @@ class LogisticRegressionWithLBFGS
private def run(input: RDD[LabeledPoint], initialWeights: Vector, userSuppliedWeights: Boolean):
LogisticRegressionModel = {
- // ml's Logisitic regression only supports binary classifcation currently.
+ // ml's Logistic regression only supports binary classification currently.
if (numOfLinearPredictor == 1) {
def runWithMlLogisitcRegression(elasticNetParam: Double) = {
// Prepare the ml LogisticRegression based on our settings
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeans.scala
index 64b838a1db..e4bd0dc25e 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeans.scala
@@ -411,7 +411,7 @@ private object BisectingKMeans extends Serializable {
private[clustering] class ClusteringTreeNode private[clustering] (
val index: Int,
val size: Long,
- private val centerWithNorm: VectorWithNorm,
+ private[clustering] val centerWithNorm: VectorWithNorm,
val cost: Double,
val height: Double,
val children: Array[ClusteringTreeNode]) extends Serializable {
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeansModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeansModel.scala
index 01a0d31f14..c3b5b8b790 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeansModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeansModel.scala
@@ -17,11 +17,19 @@
package org.apache.spark.mllib.clustering
+import org.json4s._
+import org.json4s.DefaultFormats
+import org.json4s.jackson.JsonMethods._
+import org.json4s.JsonDSL._
+
+import org.apache.spark.SparkContext
import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.internal.Logging
import org.apache.spark.mllib.linalg.Vector
+import org.apache.spark.mllib.util.{Loader, Saveable}
import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.{Row, SQLContext}
/**
* Clustering model produced by [[BisectingKMeans]].
@@ -34,7 +42,7 @@ import org.apache.spark.rdd.RDD
@Experimental
class BisectingKMeansModel private[clustering] (
private[clustering] val root: ClusteringTreeNode
- ) extends Serializable with Logging {
+ ) extends Serializable with Saveable with Logging {
/**
* Leaf cluster centers.
@@ -92,4 +100,92 @@ class BisectingKMeansModel private[clustering] (
*/
@Since("1.6.0")
def computeCost(data: JavaRDD[Vector]): Double = this.computeCost(data.rdd)
+
+ @Since("2.0.0")
+ override def save(sc: SparkContext, path: String): Unit = {
+ BisectingKMeansModel.SaveLoadV1_0.save(sc, this, path)
+ }
+
+ override protected def formatVersion: String = "1.0"
+}
+
+@Since("2.0.0")
+object BisectingKMeansModel extends Loader[BisectingKMeansModel] {
+
+ @Since("2.0.0")
+ override def load(sc: SparkContext, path: String): BisectingKMeansModel = {
+ val (loadedClassName, formatVersion, metadata) = Loader.loadMetadata(sc, path)
+ implicit val formats = DefaultFormats
+ val rootId = (metadata \ "rootId").extract[Int]
+ val classNameV1_0 = SaveLoadV1_0.thisClassName
+ (loadedClassName, formatVersion) match {
+ case (classNameV1_0, "1.0") =>
+ val model = SaveLoadV1_0.load(sc, path, rootId)
+ model
+ case _ => throw new Exception(
+ s"BisectingKMeansModel.load did not recognize model with (className, format version):" +
+ s"($loadedClassName, $formatVersion). Supported:\n" +
+ s" ($classNameV1_0, 1.0)")
+ }
+ }
+
+ private case class Data(index: Int, size: Long, center: Vector, norm: Double, cost: Double,
+ height: Double, children: Seq[Int])
+
+ private object Data {
+ def apply(r: Row): Data = Data(r.getInt(0), r.getLong(1), r.getAs[Vector](2), r.getDouble(3),
+ r.getDouble(4), r.getDouble(5), r.getSeq[Int](6))
+ }
+
+ private[clustering] object SaveLoadV1_0 {
+ private val thisFormatVersion = "1.0"
+
+ private[clustering]
+ val thisClassName = "org.apache.spark.mllib.clustering.BisectingKMeansModel"
+
+ def save(sc: SparkContext, model: BisectingKMeansModel, path: String): Unit = {
+ val sqlContext = SQLContext.getOrCreate(sc)
+ import sqlContext.implicits._
+ val metadata = compact(render(
+ ("class" -> thisClassName) ~ ("version" -> thisFormatVersion)
+ ~ ("rootId" -> model.root.index)))
+ sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path))
+
+ val data = getNodes(model.root).map(node => Data(node.index, node.size,
+ node.centerWithNorm.vector, node.centerWithNorm.norm, node.cost, node.height,
+ node.children.map(_.index)))
+ val dataRDD = sc.parallelize(data).toDF()
+ dataRDD.write.parquet(Loader.dataPath(path))
+ }
+
+ private def getNodes(node: ClusteringTreeNode): Array[ClusteringTreeNode] = {
+ if (node.children.isEmpty) {
+ Array(node)
+ } else {
+ node.children.flatMap(getNodes(_)) ++ Array(node)
+ }
+ }
+
+ def load(sc: SparkContext, path: String, rootId: Int): BisectingKMeansModel = {
+ val sqlContext = SQLContext.getOrCreate(sc)
+ val rows = sqlContext.read.parquet(Loader.dataPath(path))
+ Loader.checkSchema[Data](rows.schema)
+ val data = rows.select("index", "size", "center", "norm", "cost", "height", "children")
+ val nodes = data.rdd.map(Data.apply).collect().map(d => (d.index, d)).toMap
+ val rootNode = buildTree(rootId, nodes)
+ new BisectingKMeansModel(rootNode)
+ }
+
+ private def buildTree(rootId: Int, nodes: Map[Int, Data]): ClusteringTreeNode = {
+ val root = nodes.get(rootId).get
+ if (root.children.isEmpty) {
+ new ClusteringTreeNode(root.index, root.size, new VectorWithNorm(root.center, root.norm),
+ root.cost, root.height, new Array[ClusteringTreeNode](0))
+ } else {
+ val children = root.children.map(c => buildTree(c, nodes))
+ new ClusteringTreeNode(root.index, root.size, new VectorWithNorm(root.center, root.norm),
+ root.cost, root.height, children.toArray)
+ }
+ }
+ }
}
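A minimal usage sketch of the persistence API added to BisectingKMeansModel above; the path and the toy data are placeholders, and the SparkContext is assumed to be already constructed.

```scala
import org.apache.spark.SparkContext
import org.apache.spark.mllib.clustering.{BisectingKMeans, BisectingKMeansModel}
import org.apache.spark.mllib.linalg.Vectors

def roundTrip(sc: SparkContext, path: String): Unit = {
  val data = sc.parallelize(Seq(
    Vectors.dense(0.0, 0.0), Vectors.dense(1.0, 1.0),
    Vectors.dense(9.0, 8.0), Vectors.dense(8.0, 9.0)))
  val model = new BisectingKMeans().setK(2).run(data)
  model.save(sc, path)                                // writes metadata/ (JSON) and data/ (Parquet)
  val sameModel = BisectingKMeansModel.load(sc, path) // rebuilds the cluster tree from rootId
  assert(sameModel.clusterCenters.length == model.clusterCenters.length)
}
```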
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala
index 03eb903bb8..f04c87259c 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala
@@ -181,13 +181,12 @@ class GaussianMixture private (
val (weights, gaussians) = initialModel match {
case Some(gmm) => (gmm.weights, gmm.gaussians)
- case None => {
+ case None =>
val samples = breezeData.takeSample(withReplacement = true, k * nSamples, seed)
(Array.fill(k)(1.0 / k), Array.tabulate(k) { i =>
val slice = samples.view(i * nSamples, (i + 1) * nSamples)
new MultivariateGaussian(vectorMean(slice), initCovariance(slice))
})
- }
}
var llh = Double.MinValue // current log-likelihood
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala
index 02417b1124..f87613cc72 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala
@@ -183,7 +183,7 @@ object GaussianMixtureModel extends Loader[GaussianMixtureModel] {
val k = (metadata \ "k").extract[Int]
val classNameV1_0 = SaveLoadV1_0.classNameV1_0
(loadedClassName, version) match {
- case (classNameV1_0, "1.0") => {
+ case (classNameV1_0, "1.0") =>
val model = SaveLoadV1_0.load(sc, path)
require(model.weights.length == k,
s"GaussianMixtureModel requires weights of length $k " +
@@ -192,7 +192,6 @@ object GaussianMixtureModel extends Loader[GaussianMixtureModel] {
s"GaussianMixtureModel requires gaussians of length $k" +
s"got gaussians of length ${model.gaussians.length}")
model
- }
case _ => throw new Exception(
s"GaussianMixtureModel.load did not recognize model with (className, format version):" +
s"($loadedClassName, $version). Supported:\n" +
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
index a7beb81980..8ff0b83e8b 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
@@ -253,16 +253,14 @@ class KMeans private (
}
val centers = initialModel match {
- case Some(kMeansCenters) => {
+ case Some(kMeansCenters) =>
Array(kMeansCenters.clusterCenters.map(s => new VectorWithNorm(s)))
- }
- case None => {
+ case None =>
if (initializationMode == KMeans.RANDOM) {
initRandom(data)
} else {
initKMeansParallel(data)
}
- }
}
val initTimeInSeconds = (System.nanoTime() - initStartTime) / 1e9
logInfo(s"Initialization with $initializationMode took " + "%.3f".format(initTimeInSeconds) +
@@ -390,6 +388,8 @@ class KMeans private (
// Initialize each run's first center to a random point.
val seed = new XORShiftRandom(this.seed).nextInt()
val sample = data.takeSample(true, runs, seed).toSeq
+ // Could be empty if data is empty; fail with a better message early:
+ require(sample.size >= runs, s"Required $runs samples but got ${sample.size} from $data")
val newCenters = Array.tabulate(runs)(r => ArrayBuffer(sample(r).toDense))
/** Merges new centers to centers. */
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala
index 12813fd412..d999b9be8e 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala
@@ -130,7 +130,8 @@ class LDA private (
*/
@Since("1.5.0")
def setDocConcentration(docConcentration: Vector): this.type = {
- require(docConcentration.size > 0, "docConcentration must have > 0 elements")
+ require(docConcentration.size == 1 || docConcentration.size == k,
+ s"Size of docConcentration must be 1 or ${k} but got ${docConcentration.size}")
this.docConcentration = docConcentration
this
}
@@ -260,15 +261,18 @@ class LDA private (
def getCheckpointInterval: Int = checkpointInterval
/**
- * Period (in iterations) between checkpoints (default = 10). Checkpointing helps with recovery
+ * Parameter to set the checkpoint interval (>= 1) or disable checkpointing (-1). E.g. 10 means that
+ * the cache will get checkpointed every 10 iterations. Checkpointing helps with recovery
* (when nodes fail). It also helps with eliminating temporary shuffle files on disk, which can be
* important when LDA is run for many iterations. If the checkpoint directory is not set in
- * [[org.apache.spark.SparkContext]], this setting is ignored.
+ * [[org.apache.spark.SparkContext]], this setting is ignored. (default = 10)
*
* @see [[org.apache.spark.SparkContext#setCheckpointDir]]
*/
@Since("1.3.0")
def setCheckpointInterval(checkpointInterval: Int): this.type = {
+ require(checkpointInterval == -1 || checkpointInterval > 0,
+ s"Period between checkpoints must be -1 or positive but got ${checkpointInterval}")
this.checkpointInterval = checkpointInterval
this
}
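A sketch of how the tightened LDA setters behave after this change; the values are illustrative only.

```scala
import org.apache.spark.mllib.clustering.LDA
import org.apache.spark.mllib.linalg.Vectors

val lda = new LDA().setK(3)
lda.setDocConcentration(Vectors.dense(0.1, 0.1, 0.1)) // size == k: accepted
lda.setCheckpointInterval(-1)                         // -1 now explicitly means "disable checkpointing"
// lda.setDocConcentration(Vectors.dense(0.1, 0.1))   // size is neither 1 nor k: throws
// lda.setCheckpointInterval(0)                       // neither -1 nor positive: throws
```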
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala
index 25d67a3756..27b4004927 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala
@@ -534,7 +534,8 @@ class DistributedLDAModel private[clustering] (
@Since("1.5.0") override val docConcentration: Vector,
@Since("1.5.0") override val topicConcentration: Double,
private[spark] val iterationTimes: Array[Double],
- override protected[clustering] val gammaShape: Double = 100)
+ override protected[clustering] val gammaShape: Double = DistributedLDAModel.defaultGammaShape,
+ private[spark] val checkpointFiles: Array[String] = Array.empty[String])
extends LDAModel {
import LDA._
@@ -806,11 +807,9 @@ class DistributedLDAModel private[clustering] (
override protected def formatVersion = "1.0"
- /**
- * Java-friendly version of [[topicDistributions]]
- */
@Since("1.5.0")
override def save(sc: SparkContext, path: String): Unit = {
+ // Note: This intentionally does not save checkpointFiles.
DistributedLDAModel.SaveLoadV1_0.save(
sc, path, graph, globalTopicTotals, k, vocabSize, docConcentration, topicConcentration,
iterationTimes, gammaShape)
@@ -822,6 +821,12 @@ class DistributedLDAModel private[clustering] (
@Since("1.5.0")
object DistributedLDAModel extends Loader[DistributedLDAModel] {
+ /**
+ * The [[DistributedLDAModel]] constructor's default arguments assume gammaShape = 100
+ * to ensure equivalence in LDAModel.toLocal conversion.
+ */
+ private[clustering] val defaultGammaShape: Double = 100
+
private object SaveLoadV1_0 {
val thisFormatVersion = "1.0"
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala
index 7491ab0d51..6418f0d3b3 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala
@@ -80,9 +80,29 @@ final class EMLDAOptimizer extends LDAOptimizer {
import LDA._
+ // Adjustable parameters
+ private var keepLastCheckpoint: Boolean = true
+
/**
- * The following fields will only be initialized through the initialize() method
+ * If using checkpointing, this indicates whether to keep the last checkpoint (vs clean up).
+ */
+ @Since("2.0.0")
+ def getKeepLastCheckpoint: Boolean = this.keepLastCheckpoint
+
+ /**
+ * If using checkpointing, this indicates whether to keep the last checkpoint (vs clean up).
+ * Deleting the checkpoint can cause failures if a data partition is lost, so set this bit with
+ * care. Note that checkpoints will be cleaned up via reference counting, regardless.
+ *
+ * Default: true
*/
+ @Since("2.0.0")
+ def setKeepLastCheckpoint(keepLastCheckpoint: Boolean): this.type = {
+ this.keepLastCheckpoint = keepLastCheckpoint
+ this
+ }
+
+ // The following fields will only be initialized through the initialize() method
private[clustering] var graph: Graph[TopicCounts, TokenCount] = null
private[clustering] var k: Int = 0
private[clustering] var vocabSize: Int = 0
@@ -208,12 +228,18 @@ final class EMLDAOptimizer extends LDAOptimizer {
override private[clustering] def getLDAModel(iterationTimes: Array[Double]): LDAModel = {
require(graph != null, "graph is null, EMLDAOptimizer not initialized.")
- this.graphCheckpointer.deleteAllCheckpoints()
+ val checkpointFiles: Array[String] = if (keepLastCheckpoint) {
+ this.graphCheckpointer.deleteAllCheckpointsButLast()
+ this.graphCheckpointer.getAllCheckpointFiles
+ } else {
+ this.graphCheckpointer.deleteAllCheckpoints()
+ Array.empty[String]
+ }
// The constructor's default arguments assume gammaShape = 100 to ensure equivalence in
- // LDAModel.toLocal conversion
+ // LDAModel.toLocal conversion.
new DistributedLDAModel(this.graph, this.globalTopicTotals, this.k, this.vocabSize,
Vectors.dense(Array.fill(this.k)(this.docConcentration)), this.topicConcentration,
- iterationTimes)
+ iterationTimes, DistributedLDAModel.defaultGammaShape, checkpointFiles)
}
}
@@ -451,10 +477,11 @@ final class OnlineLDAOptimizer extends LDAOptimizer {
}
Iterator((stat, gammaPart))
}
- val statsSum: BDM[Double] = stats.map(_._1).reduce(_ += _)
+ val statsSum: BDM[Double] = stats.map(_._1).treeAggregate(BDM.zeros[Double](k, vocabSize))(
+ _ += _, _ += _)
expElogbetaBc.unpersist()
val gammat: BDM[Double] = breeze.linalg.DenseMatrix.vertcat(
- stats.map(_._2).reduce(_ ++ _).map(_.toDenseMatrix): _*)
+ stats.map(_._2).flatMap(list => list).collect().map(_.toDenseMatrix): _*)
val batchResult = statsSum :* expElogbeta.t
// Note that this is an optimization to avoid batch.count
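A wiring sketch for the new keepLastCheckpoint flag, assuming the usual LDA/EMLDAOptimizer setup; keeping the last checkpoint protects the returned DistributedLDAModel if a cached partition is later lost.

```scala
import org.apache.spark.mllib.clustering.{EMLDAOptimizer, LDA}

val optimizer = new EMLDAOptimizer().setKeepLastCheckpoint(true) // default is true
val lda = new LDA()
  .setK(10)
  .setCheckpointInterval(10)
  .setOptimizer(optimizer)
// After lda.run(corpus), the retained checkpoint files are carried on the model
// (checkpointFiles is private[spark], so they are not part of the public API).
```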
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala
index 4eb8fc049e..24e1cff0dc 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala
@@ -218,6 +218,12 @@ class StreamingKMeans @Since("1.2.0") (
*/
@Since("1.2.0")
def setInitialCenters(centers: Array[Vector], weights: Array[Double]): this.type = {
+ require(centers.size == weights.size,
+ "Number of initial centers must be equal to number of weights")
+ require(centers.size == k,
+ s"Number of initial centers must be ${k} but got ${centers.size}")
+ require(weights.forall(_ >= 0),
+ s"Weight for each inital center must be nonnegative but got [${weights.mkString(" ")}]")
model = new StreamingKMeansModel(centers, weights)
this
}
@@ -231,6 +237,10 @@ class StreamingKMeans @Since("1.2.0") (
*/
@Since("1.2.0")
def setRandomCenters(dim: Int, weight: Double, seed: Long = Utils.random.nextLong): this.type = {
+ require(dim > 0,
+ s"Number of dimensions must be positive but got ${dim}")
+ require(weight >= 0,
+ s"Weight for each center must be nonnegative but got ${weight}")
val random = new XORShiftRandom(seed)
val centers = Array.fill(k)(Vectors.dense(Array.fill(dim)(random.nextGaussian())))
val weights = Array.fill(k)(weight)
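A sketch of the new argument validation on StreamingKMeans initialization; the centers and weights here are toy values.

```scala
import org.apache.spark.mllib.clustering.StreamingKMeans
import org.apache.spark.mllib.linalg.Vectors

val skm = new StreamingKMeans().setK(2).setDecayFactor(1.0)
skm.setInitialCenters(
  Array(Vectors.dense(0.0, 0.0), Vectors.dense(5.0, 5.0)), // exactly k centers
  Array(1.0, 1.0))                                         // nonnegative weights
// skm.setInitialCenters(Array(Vectors.dense(0.0, 0.0)), Array(1.0)) // 1 center != k: throws
// skm.setRandomCenters(dim = 0, weight = 1.0)                       // dim <= 0: throws
```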
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/HashingTF.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/HashingTF.scala
index c93ed64183..47c9e850a0 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/feature/HashingTF.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/HashingTF.scala
@@ -36,12 +36,24 @@ import org.apache.spark.util.Utils
@Since("1.1.0")
class HashingTF(val numFeatures: Int) extends Serializable {
+ private var binary = false
+
/**
*/
@Since("1.1.0")
def this() = this(1 << 20)
/**
+ * If true, the term frequency vector will be binary such that non-zero term counts will be set to 1
+ * (default: false)
+ */
+ @Since("2.0.0")
+ def setBinary(value: Boolean): this.type = {
+ binary = value
+ this
+ }
+
+ /**
* Returns the index of the input term.
*/
@Since("1.1.0")
@@ -53,9 +65,10 @@ class HashingTF(val numFeatures: Int) extends Serializable {
@Since("1.1.0")
def transform(document: Iterable[_]): Vector = {
val termFrequencies = mutable.HashMap.empty[Int, Double]
+ val setTF = if (binary) (i: Int) => 1.0 else (i: Int) => termFrequencies.getOrElse(i, 0.0) + 1.0
document.foreach { term =>
val i = indexOf(term)
- termFrequencies.put(i, termFrequencies.getOrElse(i, 0.0) + 1.0)
+ termFrequencies.put(i, setTF(i))
}
Vectors.sparse(numFeatures, termFrequencies.toSeq)
}
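A small sketch of the new binary-TF toggle: with setBinary(true), repeated terms hash to 1.0 instead of their raw counts.

```scala
import org.apache.spark.mllib.feature.HashingTF

val doc = Seq("a", "a", "b")
val counts = new HashingTF(1 << 10).transform(doc)                  // term "a" maps to 2.0
val binary = new HashingTF(1 << 10).setBinary(true).transform(doc)  // term "a" maps to 1.0
```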
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala
index 4455681e50..4344ab1bad 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala
@@ -23,12 +23,22 @@ import java.util.concurrent.atomic.AtomicInteger
import scala.collection.JavaConverters._
import scala.collection.mutable
import scala.reflect.ClassTag
+import scala.reflect.runtime.universe._
+import org.json4s.DefaultFormats
+import org.json4s.JsonDSL._
+import org.json4s.jackson.JsonMethods.{compact, render}
+
+import org.apache.spark.SparkContext
import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.api.java.JavaSparkContext.fakeClassTag
import org.apache.spark.internal.Logging
+import org.apache.spark.mllib.util.{Loader, Saveable}
import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.{DataFrame, Row, SQLContext}
+import org.apache.spark.sql.catalyst.ScalaReflection
+import org.apache.spark.sql.types._
import org.apache.spark.storage.StorageLevel
/**
@@ -566,4 +576,88 @@ object PrefixSpan extends Logging {
@Since("1.5.0")
class PrefixSpanModel[Item] @Since("1.5.0") (
@Since("1.5.0") val freqSequences: RDD[PrefixSpan.FreqSequence[Item]])
- extends Serializable
+ extends Saveable with Serializable {
+
+ /**
+ * Save this model to the given path.
+ * It only works for Item datatypes supported by DataFrames.
+ *
+ * This saves:
+ * - human-readable (JSON) model metadata to path/metadata/
+ * - Parquet formatted data to path/data/
+ *
+ * The model may be loaded using [[PrefixSpanModel.load]].
+ *
+ * @param sc Spark context used to save model data.
+ * @param path Path specifying the directory in which to save this model.
+ * If the directory already exists, this method throws an exception.
+ */
+ @Since("2.0.0")
+ override def save(sc: SparkContext, path: String): Unit = {
+ PrefixSpanModel.SaveLoadV1_0.save(this, path)
+ }
+
+ override protected val formatVersion: String = "1.0"
+}
+
+@Since("2.0.0")
+object PrefixSpanModel extends Loader[PrefixSpanModel[_]] {
+
+ @Since("2.0.0")
+ override def load(sc: SparkContext, path: String): PrefixSpanModel[_] = {
+ PrefixSpanModel.SaveLoadV1_0.load(sc, path)
+ }
+
+ private[fpm] object SaveLoadV1_0 {
+
+ private val thisFormatVersion = "1.0"
+
+ private val thisClassName = "org.apache.spark.mllib.fpm.PrefixSpanModel"
+
+ def save(model: PrefixSpanModel[_], path: String): Unit = {
+ val sc = model.freqSequences.sparkContext
+ val sqlContext = SQLContext.getOrCreate(sc)
+
+ val metadata = compact(render(
+ ("class" -> thisClassName) ~ ("version" -> thisFormatVersion)))
+ sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path))
+
+ // Get the type of item class
+ val sample = model.freqSequences.first().sequence(0)(0)
+ val className = sample.getClass.getCanonicalName
+ val classSymbol = runtimeMirror(getClass.getClassLoader).staticClass(className)
+ val tpe = classSymbol.selfType
+
+ val itemType = ScalaReflection.schemaFor(tpe).dataType
+ val fields = Array(StructField("sequence", ArrayType(ArrayType(itemType))),
+ StructField("freq", LongType))
+ val schema = StructType(fields)
+ val rowDataRDD = model.freqSequences.map { x =>
+ Row(x.sequence, x.freq)
+ }
+ sqlContext.createDataFrame(rowDataRDD, schema).write.parquet(Loader.dataPath(path))
+ }
+
+ def load(sc: SparkContext, path: String): PrefixSpanModel[_] = {
+ implicit val formats = DefaultFormats
+ val sqlContext = SQLContext.getOrCreate(sc)
+
+ val (className, formatVersion, metadata) = Loader.loadMetadata(sc, path)
+ assert(className == thisClassName)
+ assert(formatVersion == thisFormatVersion)
+
+ val freqSequences = sqlContext.read.parquet(Loader.dataPath(path))
+ val sample = freqSequences.select("sequence").head().get(0)
+ loadImpl(freqSequences, sample)
+ }
+
+ def loadImpl[Item: ClassTag](freqSequences: DataFrame, sample: Item): PrefixSpanModel[Item] = {
+ val freqSequencesRDD = freqSequences.select("sequence", "freq").rdd.map { x =>
+ val sequence = x.getAs[Seq[Seq[Item]]](0).map(_.toArray).toArray
+ val freq = x.getLong(1)
+ new PrefixSpan.FreqSequence(sequence, freq)
+ }
+ new PrefixSpanModel(freqSequencesRDD)
+ }
+ }
+}
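A usage sketch for the new PrefixSpanModel persistence; the path is a placeholder and the SparkContext is assumed to exist.

```scala
import org.apache.spark.SparkContext
import org.apache.spark.mllib.fpm.{PrefixSpan, PrefixSpanModel}

def saveAndReload(sc: SparkContext, path: String): Unit = {
  val sequences = sc.parallelize(Seq(
    Array(Array(1, 2), Array(3)),
    Array(Array(1), Array(3, 2), Array(1, 2))), 2).cache()
  val model = new PrefixSpan().setMinSupport(0.5).setMaxPatternLength(5).run(sequences)
  model.save(sc, path)                           // JSON metadata + Parquet data
  val sameModel = PrefixSpanModel.load(sc, path) // item type is recovered via reflection
}
```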
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicCheckpointer.scala b/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicCheckpointer.scala
index 391f89aa14..5c12c9305b 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicCheckpointer.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicCheckpointer.scala
@@ -52,7 +52,8 @@ import org.apache.spark.storage.StorageLevel
* - This class removes checkpoint files once later Datasets have been checkpointed.
* However, references to the older Datasets will still return isCheckpointed = true.
*
- * @param checkpointInterval Datasets will be checkpointed at this interval
+ * @param checkpointInterval Datasets will be checkpointed at this interval.
+ * If this interval is set to -1, then checkpointing will be disabled.
* @param sc SparkContext for the Datasets given to this checkpointer
* @tparam T Dataset type, such as RDD[Double]
*/
@@ -89,7 +90,8 @@ private[mllib] abstract class PeriodicCheckpointer[T](
updateCount += 1
// Handle checkpointing (after persisting)
- if ((updateCount % checkpointInterval) == 0 && sc.getCheckpointDir.nonEmpty) {
+ if (checkpointInterval != -1 && (updateCount % checkpointInterval) == 0
+ && sc.getCheckpointDir.nonEmpty) {
// Add new checkpoint before removing old checkpoints.
checkpoint(newData)
checkpointQueue.enqueue(newData)
@@ -134,6 +136,24 @@ private[mllib] abstract class PeriodicCheckpointer[T](
}
/**
+ * Call this at the end to delete any remaining checkpoint files, except for the last checkpoint.
+ * Note that there may not be any checkpoints at all.
+ */
+ def deleteAllCheckpointsButLast(): Unit = {
+ while (checkpointQueue.size > 1) {
+ removeCheckpointFile()
+ }
+ }
+
+ /**
+ * Get all current checkpoint files.
+ * This is useful in combination with [[deleteAllCheckpointsButLast()]].
+ */
+ def getAllCheckpointFiles: Array[String] = {
+ checkpointQueue.flatMap(getCheckpointFiles).toArray
+ }
+
+ /**
* Dequeue the oldest checkpointed Dataset, and remove its checkpoint files.
* This prints a warning but does not fail if the files cannot be removed.
*/
@@ -141,15 +161,20 @@ private[mllib] abstract class PeriodicCheckpointer[T](
val old = checkpointQueue.dequeue()
// Since the old checkpoint is not deleted by Spark, we manually delete it.
val fs = FileSystem.get(sc.hadoopConfiguration)
- getCheckpointFiles(old).foreach { checkpointFile =>
- try {
- fs.delete(new Path(checkpointFile), true)
- } catch {
- case e: Exception =>
- logWarning("PeriodicCheckpointer could not remove old checkpoint file: " +
- checkpointFile)
- }
- }
+ getCheckpointFiles(old).foreach(PeriodicCheckpointer.removeCheckpointFile(_, fs))
}
+}
+
+private[spark] object PeriodicCheckpointer extends Logging {
+ /** Delete a checkpoint file, and log a warning if deletion fails. */
+ def removeCheckpointFile(path: String, fs: FileSystem): Unit = {
+ try {
+ fs.delete(new Path(path), true)
+ } catch {
+ case e: Exception =>
+ logWarning("PeriodicCheckpointer could not remove old checkpoint file: " +
+ path)
+ }
+ }
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointer.scala b/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointer.scala
index 11a059536c..20db6084d0 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointer.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointer.scala
@@ -69,7 +69,8 @@ import org.apache.spark.storage.StorageLevel
* // checkpointed: graph4
* }}}
*
- * @param checkpointInterval Graphs will be checkpointed at this interval
+ * @param checkpointInterval Graphs will be checkpointed at this interval.
+ * If this interval is set to -1, then checkpointing will be disabled.
* @tparam VD Vertex descriptor type
* @tparam ED Edge descriptor type
*
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
index c6de7751f5..8c09b69b3c 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
@@ -123,14 +123,18 @@ sealed trait Matrix extends Serializable {
@Since("1.4.0")
def toString(maxLines: Int, maxLineWidth: Int): String = toBreeze.toString(maxLines, maxLineWidth)
- /** Map the values of this matrix using a function. Generates a new matrix. Performs the
- * function on only the backing array. For example, an operation such as addition or
- * subtraction will only be performed on the non-zero values in a `SparseMatrix`. */
+ /**
+ * Map the values of this matrix using a function. Generates a new matrix. Performs the
+ * function on only the backing array. For example, an operation such as addition or
+ * subtraction will only be performed on the non-zero values in a `SparseMatrix`.
+ */
private[spark] def map(f: Double => Double): Matrix
- /** Update all the values of this matrix using the function f. Performed in-place on the
- * backing array. For example, an operation such as addition or subtraction will only be
- * performed on the non-zero values in a `SparseMatrix`. */
+ /**
+ * Update all the values of this matrix using the function f. Performed in-place on the
+ * backing array. For example, an operation such as addition or subtraction will only be
+ * performed on the non-zero values in a `SparseMatrix`.
+ */
private[mllib] def update(f: Double => Double): Matrix
/**
@@ -613,7 +617,7 @@ class SparseMatrix @Since("1.3.0") (
private[mllib] def update(i: Int, j: Int, v: Double): Unit = {
val ind = index(i, j)
- if (ind == -1) {
+ if (ind < 0) {
throw new NoSuchElementException("The given row and column indices correspond to a zero " +
"value. Only non-zero elements in Sparse Matrices can be updated.")
} else {
@@ -940,8 +944,16 @@ object Matrices {
case dm: BDM[Double] =>
new DenseMatrix(dm.rows, dm.cols, dm.data, dm.isTranspose)
case sm: BSM[Double] =>
+ // SPARK-11507: work around Breeze issue 479.
+ val mat = if (sm.colPtrs.last != sm.data.length) {
+ val matCopy = sm.copy
+ matCopy.compact()
+ matCopy
+ } else {
+ sm
+ }
// There is no isTranspose flag for sparse matrices in Breeze
- new SparseMatrix(sm.rows, sm.cols, sm.colPtrs, sm.rowIndices, sm.data)
+ new SparseMatrix(mat.rows, mat.cols, mat.colPtrs, mat.rowIndices, mat.data)
case _ =>
throw new UnsupportedOperationException(
s"Do not support conversion from type ${breeze.getClass.getName}.")
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
index 0f0c3a2df5..5812cdde2c 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
@@ -186,7 +186,7 @@ sealed trait Vector extends Serializable {
* :: AlphaComponent ::
*
* User-defined type for [[Vector]] which allows easy interaction with SQL
- * via [[org.apache.spark.sql.DataFrame]].
+ * via [[org.apache.spark.sql.Dataset]].
*/
@AlphaComponent
class VectorUDT extends UserDefinedType[Vector] {
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala
index e8f4422fd4..84764963b5 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala
@@ -81,8 +81,8 @@ class StreamingLinearRegressionWithSGD private[mllib] (
}
/**
- * Set the number of iterations of gradient descent to run per update. Default: 50.
- */
+ * Set the number of iterations of gradient descent to run per update. Default: 50.
+ */
@Since("1.1.0")
def setNumIterations(numIterations: Int): this.type = {
this.algorithm.optimizer.setNumIterations(numIterations)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala
index 052b5b1d65..6c6e9fb7c6 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala
@@ -61,15 +61,17 @@ class MultivariateGaussian @Since("1.3.0") (
*/
private val (rootSigmaInv: DBM[Double], u: Double) = calculateCovarianceConstants
- /** Returns density of this multivariate Gaussian at given point, x
- */
+ /**
+ * Returns density of this multivariate Gaussian at given point, x
+ */
@Since("1.3.0")
def pdf(x: Vector): Double = {
pdf(x.toBreeze)
}
- /** Returns the log-density of this multivariate Gaussian at given point, x
- */
+ /**
+ * Returns the log-density of this multivariate Gaussian at given point, x
+ */
@Since("1.3.0")
def logpdf(x: Vector): Double = {
logpdf(x.toBreeze)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/KolmogorovSmirnovTest.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/KolmogorovSmirnovTest.scala
index baf9e5e7d1..9748fbf2c9 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/KolmogorovSmirnovTest.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/KolmogorovSmirnovTest.scala
@@ -166,7 +166,7 @@ private[stat] object KolmogorovSmirnovTest extends Logging {
: KolmogorovSmirnovTestResult = {
val distObj =
distName match {
- case "norm" => {
+ case "norm" =>
if (params.nonEmpty) {
// parameters are passed, then can only be 2
require(params.length == 2, "Normal distribution requires mean and standard " +
@@ -178,7 +178,6 @@ private[stat] object KolmogorovSmirnovTest extends Logging {
"initialized to standard normal (i.e. N(0, 1))")
new NormalDistribution(0, 1)
}
- }
case _ => throw new UnsupportedOperationException(s"$distName not yet supported through" +
s" convenience method. Current options are:['norm'].")
}
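A sketch of the public convenience entry point that this match backs: the "norm" case with explicit (mean, stddev) parameters, or none for the standard normal.

```scala
import org.apache.spark.SparkContext
import org.apache.spark.mllib.stat.Statistics

def ksExample(sc: SparkContext): Unit = {
  val data = sc.parallelize(Seq(0.1, -0.4, 1.2, 0.7, -1.3))
  val standard = Statistics.kolmogorovSmirnovTest(data, "norm")           // N(0, 1)
  val shifted  = Statistics.kolmogorovSmirnovTest(data, "norm", 0.5, 2.0) // N(0.5, 2.0)
  println(s"p-values: ${standard.pValue}, ${shifted.pValue}")
}
```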
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala
index d166dc7905..7fe60e2d99 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala
@@ -20,15 +20,11 @@ package org.apache.spark.mllib.tree
import org.apache.spark.annotation.Since
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.internal.Logging
-import org.apache.spark.mllib.impl.PeriodicRDDCheckpointer
+import org.apache.spark.ml.tree.impl.{GradientBoostedTrees => NewGBT}
import org.apache.spark.mllib.regression.LabeledPoint
-import org.apache.spark.mllib.tree.configuration.Algo._
import org.apache.spark.mllib.tree.configuration.BoostingStrategy
-import org.apache.spark.mllib.tree.impl.TimeTracker
-import org.apache.spark.mllib.tree.impurity.Variance
-import org.apache.spark.mllib.tree.model.{DecisionTreeModel, GradientBoostedTreesModel}
+import org.apache.spark.mllib.tree.model.GradientBoostedTreesModel
import org.apache.spark.rdd.RDD
-import org.apache.spark.storage.StorageLevel
/**
* A class that implements
@@ -70,17 +66,8 @@ class GradientBoostedTrees private[spark] (
@Since("1.2.0")
def run(input: RDD[LabeledPoint]): GradientBoostedTreesModel = {
val algo = boostingStrategy.treeStrategy.algo
- algo match {
- case Regression =>
- GradientBoostedTrees.boost(input, input, boostingStrategy, validate = false, seed)
- case Classification =>
- // Map labels to -1, +1 so binary classification can be treated as regression.
- val remappedInput = input.map(x => new LabeledPoint((x.label * 2) - 1, x.features))
- GradientBoostedTrees.boost(remappedInput, remappedInput, boostingStrategy, validate = false,
- seed)
- case _ =>
- throw new IllegalArgumentException(s"$algo is not supported by the gradient boosting.")
- }
+ val (trees, treeWeights) = NewGBT.run(input, boostingStrategy, seed.toLong)
+ new GradientBoostedTreesModel(algo, trees.map(_.toOld), treeWeights)
}
/**
@@ -107,20 +94,9 @@ class GradientBoostedTrees private[spark] (
input: RDD[LabeledPoint],
validationInput: RDD[LabeledPoint]): GradientBoostedTreesModel = {
val algo = boostingStrategy.treeStrategy.algo
- algo match {
- case Regression =>
- GradientBoostedTrees.boost(input, validationInput, boostingStrategy, validate = true, seed)
- case Classification =>
- // Map labels to -1, +1 so binary classification can be treated as regression.
- val remappedInput = input.map(
- x => new LabeledPoint((x.label * 2) - 1, x.features))
- val remappedValidationInput = validationInput.map(
- x => new LabeledPoint((x.label * 2) - 1, x.features))
- GradientBoostedTrees.boost(remappedInput, remappedValidationInput, boostingStrategy,
- validate = true, seed)
- case _ =>
- throw new IllegalArgumentException(s"$algo is not supported by the gradient boosting.")
- }
+ val (trees, treeWeights) = NewGBT.runWithValidation(input, validationInput, boostingStrategy,
+ seed.toLong)
+ new GradientBoostedTreesModel(algo, trees.map(_.toOld), treeWeights)
}
/**
@@ -162,147 +138,4 @@ object GradientBoostedTrees extends Logging {
boostingStrategy: BoostingStrategy): GradientBoostedTreesModel = {
train(input.rdd, boostingStrategy)
}
-
- /**
- * Internal method for performing regression using trees as base learners.
- * @param input Training dataset.
- * @param validationInput Validation dataset, ignored if validate is set to false.
- * @param boostingStrategy Boosting parameters.
- * @param validate Whether or not to use the validation dataset.
- * @param seed Random seed.
- * @return GradientBoostedTreesModel that can be used for prediction.
- */
- private def boost(
- input: RDD[LabeledPoint],
- validationInput: RDD[LabeledPoint],
- boostingStrategy: BoostingStrategy,
- validate: Boolean,
- seed: Int): GradientBoostedTreesModel = {
- val timer = new TimeTracker()
- timer.start("total")
- timer.start("init")
-
- boostingStrategy.assertValid()
-
- // Initialize gradient boosting parameters
- val numIterations = boostingStrategy.numIterations
- val baseLearners = new Array[DecisionTreeModel](numIterations)
- val baseLearnerWeights = new Array[Double](numIterations)
- val loss = boostingStrategy.loss
- val learningRate = boostingStrategy.learningRate
- // Prepare strategy for individual trees, which use regression with variance impurity.
- val treeStrategy = boostingStrategy.treeStrategy.copy
- val validationTol = boostingStrategy.validationTol
- treeStrategy.algo = Regression
- treeStrategy.impurity = Variance
- treeStrategy.assertValid()
-
- // Cache input
- val persistedInput = if (input.getStorageLevel == StorageLevel.NONE) {
- input.persist(StorageLevel.MEMORY_AND_DISK)
- true
- } else {
- false
- }
-
- // Prepare periodic checkpointers
- val predErrorCheckpointer = new PeriodicRDDCheckpointer[(Double, Double)](
- treeStrategy.getCheckpointInterval, input.sparkContext)
- val validatePredErrorCheckpointer = new PeriodicRDDCheckpointer[(Double, Double)](
- treeStrategy.getCheckpointInterval, input.sparkContext)
-
- timer.stop("init")
-
- logDebug("##########")
- logDebug("Building tree 0")
- logDebug("##########")
-
- // Initialize tree
- timer.start("building tree 0")
- val firstTreeModel = new DecisionTree(treeStrategy, seed).run(input)
- val firstTreeWeight = 1.0
- baseLearners(0) = firstTreeModel
- baseLearnerWeights(0) = firstTreeWeight
-
- var predError: RDD[(Double, Double)] = GradientBoostedTreesModel.
- computeInitialPredictionAndError(input, firstTreeWeight, firstTreeModel, loss)
- predErrorCheckpointer.update(predError)
- logDebug("error of gbt = " + predError.values.mean())
-
- // Note: A model of type regression is used since we require raw prediction
- timer.stop("building tree 0")
-
- var validatePredError: RDD[(Double, Double)] = GradientBoostedTreesModel.
- computeInitialPredictionAndError(validationInput, firstTreeWeight, firstTreeModel, loss)
- if (validate) validatePredErrorCheckpointer.update(validatePredError)
- var bestValidateError = if (validate) validatePredError.values.mean() else 0.0
- var bestM = 1
-
- var m = 1
- var doneLearning = false
- while (m < numIterations && !doneLearning) {
- // Update data with pseudo-residuals
- val data = predError.zip(input).map { case ((pred, _), point) =>
- LabeledPoint(-loss.gradient(pred, point.label), point.features)
- }
-
- timer.start(s"building tree $m")
- logDebug("###################################################")
- logDebug("Gradient boosting tree iteration " + m)
- logDebug("###################################################")
- val model = new DecisionTree(treeStrategy, seed + m).run(data)
- timer.stop(s"building tree $m")
- // Update partial model
- baseLearners(m) = model
- // Note: The setting of baseLearnerWeights is incorrect for losses other than SquaredError.
- // Technically, the weight should be optimized for the particular loss.
- // However, the behavior should be reasonable, though not optimal.
- baseLearnerWeights(m) = learningRate
-
- predError = GradientBoostedTreesModel.updatePredictionError(
- input, predError, baseLearnerWeights(m), baseLearners(m), loss)
- predErrorCheckpointer.update(predError)
- logDebug("error of gbt = " + predError.values.mean())
-
- if (validate) {
- // Stop training early if
- // 1. Reduction in error is less than the validationTol or
- // 2. If the error increases, that is if the model is overfit.
- // We want the model returned corresponding to the best validation error.
-
- validatePredError = GradientBoostedTreesModel.updatePredictionError(
- validationInput, validatePredError, baseLearnerWeights(m), baseLearners(m), loss)
- validatePredErrorCheckpointer.update(validatePredError)
- val currentValidateError = validatePredError.values.mean()
- if (bestValidateError - currentValidateError < validationTol * Math.max(
- currentValidateError, 0.01)) {
- doneLearning = true
- } else if (currentValidateError < bestValidateError) {
- bestValidateError = currentValidateError
- bestM = m + 1
- }
- }
- m += 1
- }
-
- timer.stop("total")
-
- logInfo("Internal timing for DecisionTree:")
- logInfo(s"$timer")
-
- predErrorCheckpointer.deleteAllCheckpoints()
- validatePredErrorCheckpointer.deleteAllCheckpoints()
- if (persistedInput) input.unpersist()
-
- if (validate) {
- new GradientBoostedTreesModel(
- boostingStrategy.treeStrategy.algo,
- baseLearners.slice(0, bestM),
- baseLearnerWeights.slice(0, bestM))
- } else {
- new GradientBoostedTreesModel(
- boostingStrategy.treeStrategy.algo, baseLearners, baseLearnerWeights)
- }
- }
-
}
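A usage sketch confirming the public entry points are unchanged by the delegation to the ml implementation; the RDD of labeled points is assumed to exist.

```scala
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.GradientBoostedTrees
import org.apache.spark.mllib.tree.configuration.BoostingStrategy
import org.apache.spark.rdd.RDD

def trainGbt(data: RDD[LabeledPoint]) = {
  val boostingStrategy = BoostingStrategy.defaultParams("Regression")
  boostingStrategy.numIterations = 10
  GradientBoostedTrees.train(data, boostingStrategy) // now delegates to ml.tree.impl.GradientBoostedTrees
}
```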
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
index 1841fa4a95..26755849ad 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
@@ -55,10 +55,15 @@ import org.apache.spark.util.Utils
* @param numTrees If 1, then no bootstrapping is used. If > 1, then bootstrapping is done.
* @param featureSubsetStrategy Number of features to consider for splits at each node.
* Supported values: "auto", "all", "sqrt", "log2", "onethird".
+ * Supported numerical values: "(0.0-1.0]", "[1-n]".
* If "auto" is set, this parameter is set based on numTrees:
* if numTrees == 1, set to "all";
* if numTrees > 1 (forest) set to "sqrt" for classification and
* to "onethird" for regression.
+ * If a real value "n" in the range (0, 1.0] is set,
+ * use n * number of features.
+ * If an integer value "n" in the range (1, num features) is set,
+ * use n features.
* @param seed Random seed for bootstrapping and choosing feature subsets.
*/
private class RandomForest (
@@ -70,9 +75,11 @@ private class RandomForest (
strategy.assertValid()
require(numTrees > 0, s"RandomForest requires numTrees > 0, but was given numTrees = $numTrees.")
- require(RandomForest.supportedFeatureSubsetStrategies.contains(featureSubsetStrategy),
+ require(RandomForest.supportedFeatureSubsetStrategies.contains(featureSubsetStrategy)
+ || featureSubsetStrategy.matches(NewRFParams.supportedFeatureSubsetStrategiesRegex),
s"RandomForest given invalid featureSubsetStrategy: $featureSubsetStrategy." +
- s" Supported values: ${RandomForest.supportedFeatureSubsetStrategies.mkString(", ")}.")
+ s" Supported values: ${RandomForest.supportedFeatureSubsetStrategies.mkString(", ")}," +
+ s" (0.0-1.0], [1-n].")
/**
* Method to train a decision tree model over an RDD
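A sketch of the newly accepted numeric featureSubsetStrategy values through the public trainClassifier entry point; the remaining parameters are typical values, not requirements of this patch.

```scala
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.RandomForest
import org.apache.spark.rdd.RDD

def trainWithFraction(data: RDD[LabeledPoint]) = {
  RandomForest.trainClassifier(
    data,
    numClasses = 2,
    categoricalFeaturesInfo = Map[Int, Int](),
    numTrees = 20,
    featureSubsetStrategy = "0.5", // newly accepted: use 50% of the features at each split
    impurity = "gini",
    maxDepth = 5,
    maxBins = 32,
    seed = 42)
}
```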
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/NodeIdCache.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/NodeIdCache.scala
deleted file mode 100644
index dc7e969f7b..0000000000
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/NodeIdCache.scala
+++ /dev/null
@@ -1,195 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.mllib.tree.impl
-
-import scala.collection.mutable
-
-import org.apache.hadoop.fs.{FileSystem, Path}
-
-import org.apache.spark.annotation.DeveloperApi
-import org.apache.spark.mllib.tree.configuration.FeatureType._
-import org.apache.spark.mllib.tree.model.{Bin, Node, Split}
-import org.apache.spark.rdd.RDD
-import org.apache.spark.storage.StorageLevel
-
-/**
- * :: DeveloperApi ::
- * This is used by the node id cache to find the child id that a data point would belong to.
- * @param split Split information.
- * @param nodeIndex The current node index of a data point that this will update.
- */
-@DeveloperApi
-private[tree] case class NodeIndexUpdater(
- split: Split,
- nodeIndex: Int) {
- /**
- * Determine a child node index based on the feature value and the split.
- * @param binnedFeatures Binned feature values.
- * @param bins Bin information to convert the bin indices to approximate feature values.
- * @return Child node index to update to.
- */
- def updateNodeIndex(binnedFeatures: Array[Int], bins: Array[Array[Bin]]): Int = {
- if (split.featureType == Continuous) {
- val featureIndex = split.feature
- val binIndex = binnedFeatures(featureIndex)
- val featureValueUpperBound = bins(featureIndex)(binIndex).highSplit.threshold
- if (featureValueUpperBound <= split.threshold) {
- Node.leftChildIndex(nodeIndex)
- } else {
- Node.rightChildIndex(nodeIndex)
- }
- } else {
- if (split.categories.contains(binnedFeatures(split.feature).toDouble)) {
- Node.leftChildIndex(nodeIndex)
- } else {
- Node.rightChildIndex(nodeIndex)
- }
- }
- }
-}
-
-/**
- * :: DeveloperApi ::
- * A given TreePoint would belong to a particular node per tree.
- * Each row in the nodeIdsForInstances RDD is an array over trees of the node index
- * in each tree. Initially, values should all be 1 for root node.
- * The nodeIdsForInstances RDD needs to be updated at each iteration.
- * @param nodeIdsForInstances The initial values in the cache
- * (should be an Array of all 1's (meaning the root nodes)).
- * @param checkpointInterval The checkpointing interval
- * (how often should the cache be checkpointed.).
- */
-@DeveloperApi
-private[spark] class NodeIdCache(
- var nodeIdsForInstances: RDD[Array[Int]],
- val checkpointInterval: Int) {
-
- // Keep a reference to a previous node Ids for instances.
- // Because we will keep on re-persisting updated node Ids,
- // we want to unpersist the previous RDD.
- private var prevNodeIdsForInstances: RDD[Array[Int]] = null
-
- // To keep track of the past checkpointed RDDs.
- private val checkpointQueue = mutable.Queue[RDD[Array[Int]]]()
- private var rddUpdateCount = 0
-
- /**
- * Update the node index values in the cache.
- * This updates the RDD and its lineage.
- * TODO: Passing bin information to executors seems unnecessary and costly.
- * @param data The RDD of training rows.
- * @param nodeIdUpdaters A map of node index updaters.
- * The key is the indices of nodes that we want to update.
- * @param bins Bin information needed to find child node indices.
- */
- def updateNodeIndices(
- data: RDD[BaggedPoint[TreePoint]],
- nodeIdUpdaters: Array[mutable.Map[Int, NodeIndexUpdater]],
- bins: Array[Array[Bin]]): Unit = {
- if (prevNodeIdsForInstances != null) {
- // Unpersist the previous one if one exists.
- prevNodeIdsForInstances.unpersist()
- }
-
- prevNodeIdsForInstances = nodeIdsForInstances
- nodeIdsForInstances = data.zip(nodeIdsForInstances).map {
- case (point, node) => {
- var treeId = 0
- while (treeId < nodeIdUpdaters.length) {
- val nodeIdUpdater = nodeIdUpdaters(treeId).getOrElse(node(treeId), null)
- if (nodeIdUpdater != null) {
- val newNodeIndex = nodeIdUpdater.updateNodeIndex(
- binnedFeatures = point.datum.binnedFeatures,
- bins = bins)
- node(treeId) = newNodeIndex
- }
-
- treeId += 1
- }
-
- node
- }
- }
-
- // Keep on persisting new ones.
- nodeIdsForInstances.persist(StorageLevel.MEMORY_AND_DISK)
- rddUpdateCount += 1
-
- // Handle checkpointing if the directory is not None.
- if (nodeIdsForInstances.sparkContext.getCheckpointDir.nonEmpty &&
- (rddUpdateCount % checkpointInterval) == 0) {
- // Let's see if we can delete previous checkpoints.
- var canDelete = true
- while (checkpointQueue.size > 1 && canDelete) {
- // We can delete the oldest checkpoint iff
- // the next checkpoint actually exists in the file system.
- if (checkpointQueue.get(1).get.getCheckpointFile.isDefined) {
- val old = checkpointQueue.dequeue()
-
- // Since the old checkpoint is not deleted by Spark,
- // we'll manually delete it here.
- val fs = FileSystem.get(old.sparkContext.hadoopConfiguration)
- fs.delete(new Path(old.getCheckpointFile.get), true)
- } else {
- canDelete = false
- }
- }
-
- nodeIdsForInstances.checkpoint()
- checkpointQueue.enqueue(nodeIdsForInstances)
- }
- }
-
- /**
- * Call this after training is finished to delete any remaining checkpoints.
- */
- def deleteAllCheckpoints(): Unit = {
- while (checkpointQueue.nonEmpty) {
- val old = checkpointQueue.dequeue()
- for (checkpointFile <- old.getCheckpointFile) {
- val fs = FileSystem.get(old.sparkContext.hadoopConfiguration)
- fs.delete(new Path(checkpointFile), true)
- }
- }
- if (prevNodeIdsForInstances != null) {
- // Unpersist the previous one if one exists.
- prevNodeIdsForInstances.unpersist()
- }
- }
-}
-
-private[spark] object NodeIdCache {
- /**
- * Initialize the node Id cache with initial node Id values.
- * @param data The RDD of training rows.
- * @param numTrees The number of trees that we want to create cache for.
- * @param checkpointInterval The checkpointing interval
- * (how often should the cache be checkpointed.).
- * @param initVal The initial values in the cache.
- * @return A node Id cache containing an RDD of initial root node Indices.
- */
- def init(
- data: RDD[BaggedPoint[TreePoint]],
- numTrees: Int,
- checkpointInterval: Int,
- initVal: Int = 1): NodeIdCache = {
- new NodeIdCache(
- data.map(_ => Array.fill[Int](numTrees)(initVal)),
- checkpointInterval)
- }
-}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala
deleted file mode 100644
index 21919d69a3..0000000000
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala
+++ /dev/null
@@ -1,150 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.mllib.tree.impl
-
-import org.apache.spark.mllib.regression.LabeledPoint
-import org.apache.spark.mllib.tree.model.Bin
-import org.apache.spark.rdd.RDD
-
-
-/**
- * Internal representation of LabeledPoint for DecisionTree.
- * This bins feature values based on a subsampled of data as follows:
- * (a) Continuous features are binned into ranges.
- * (b) Unordered categorical features are binned based on subsets of feature values.
- * "Unordered categorical features" are categorical features with low arity used in
- * multiclass classification.
- * (c) Ordered categorical features are binned based on feature values.
- * "Ordered categorical features" are categorical features with high arity,
- * or any categorical feature used in regression or binary classification.
- *
- * @param label Label from LabeledPoint
- * @param binnedFeatures Binned feature values.
- * Same length as LabeledPoint.features, but values are bin indices.
- */
-private[spark] class TreePoint(val label: Double, val binnedFeatures: Array[Int])
- extends Serializable {
-}
-
-private[spark] object TreePoint {
-
- /**
- * Convert an input dataset into its TreePoint representation,
- * binning feature values in preparation for DecisionTree training.
- * @param input Input dataset.
- * @param bins Bins for features, of size (numFeatures, numBins).
- * @param metadata Learning and dataset metadata
- * @return TreePoint dataset representation
- */
- def convertToTreeRDD(
- input: RDD[LabeledPoint],
- bins: Array[Array[Bin]],
- metadata: DecisionTreeMetadata): RDD[TreePoint] = {
- // Construct arrays for featureArity for efficiency in the inner loop.
- val featureArity: Array[Int] = new Array[Int](metadata.numFeatures)
- var featureIndex = 0
- while (featureIndex < metadata.numFeatures) {
- featureArity(featureIndex) = metadata.featureArity.getOrElse(featureIndex, 0)
- featureIndex += 1
- }
- input.map { x =>
- TreePoint.labeledPointToTreePoint(x, bins, featureArity)
- }
- }
-
- /**
- * Convert one LabeledPoint into its TreePoint representation.
- * @param bins Bins for features, of size (numFeatures, numBins).
- * @param featureArity Array indexed by feature, with value 0 for continuous and numCategories
- * for categorical features.
- */
- private def labeledPointToTreePoint(
- labeledPoint: LabeledPoint,
- bins: Array[Array[Bin]],
- featureArity: Array[Int]): TreePoint = {
- val numFeatures = labeledPoint.features.size
- val arr = new Array[Int](numFeatures)
- var featureIndex = 0
- while (featureIndex < numFeatures) {
- arr(featureIndex) = findBin(featureIndex, labeledPoint, featureArity(featureIndex),
- bins)
- featureIndex += 1
- }
- new TreePoint(labeledPoint.label, arr)
- }
-
- /**
- * Find bin for one (labeledPoint, feature).
- *
- * @param featureArity 0 for continuous features; number of categories for categorical features.
- * @param bins Bins for features, of size (numFeatures, numBins).
- */
- private def findBin(
- featureIndex: Int,
- labeledPoint: LabeledPoint,
- featureArity: Int,
- bins: Array[Array[Bin]]): Int = {
-
- /**
- * Binary search helper method for continuous feature.
- */
- def binarySearchForBins(): Int = {
- val binForFeatures = bins(featureIndex)
- val feature = labeledPoint.features(featureIndex)
- var left = 0
- var right = binForFeatures.length - 1
- while (left <= right) {
- val mid = left + (right - left) / 2
- val bin = binForFeatures(mid)
- val lowThreshold = bin.lowSplit.threshold
- val highThreshold = bin.highSplit.threshold
- if ((lowThreshold < feature) && (highThreshold >= feature)) {
- return mid
- } else if (lowThreshold >= feature) {
- right = mid - 1
- } else {
- left = mid + 1
- }
- }
- -1
- }
-
- if (featureArity == 0) {
- // Perform binary search for finding bin for continuous features.
- val binIndex = binarySearchForBins()
- if (binIndex == -1) {
- throw new RuntimeException("No bin was found for continuous feature." +
- " This error can occur when given invalid data values (such as NaN)." +
- s" Feature index: $featureIndex. Feature value: ${labeledPoint.features(featureIndex)}")
- }
- binIndex
- } else {
- // Categorical feature bins are indexed by feature values.
- val featureValue = labeledPoint.features(featureIndex)
- if (featureValue < 0 || featureValue >= featureArity) {
- throw new IllegalArgumentException(
- s"DecisionTree given invalid data:" +
- s" Feature $featureIndex is categorical with values in" +
- s" {0,...,${featureArity - 1}," +
- s" but a data point gives it value $featureValue.\n" +
- " Bad data point: " + labeledPoint.toString)
- }
- featureValue.toInt
- }
- }
-}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala
index 13aff11007..ff7700d2d1 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala
@@ -85,7 +85,7 @@ object Entropy extends Impurity {
* Note: Instances of this class do not hold the data; they operate on views of the data.
* @param numClasses Number of classes for label.
*/
-private[tree] class EntropyAggregator(numClasses: Int)
+private[spark] class EntropyAggregator(numClasses: Int)
extends ImpurityAggregator(numClasses) with Serializable {
/**
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala
index 39c7f9c3be..58dc79b739 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala
@@ -81,7 +81,7 @@ object Gini extends Impurity {
* Note: Instances of this class do not hold the data; they operate on views of the data.
* @param numClasses Number of classes for label.
*/
-private[tree] class GiniAggregator(numClasses: Int)
+private[spark] class GiniAggregator(numClasses: Int)
extends ImpurityAggregator(numClasses) with Serializable {
/**
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala
index 92d74a1b83..2423516123 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala
@@ -71,7 +71,7 @@ object Variance extends Impurity {
* in order to compute impurity from a sample.
* Note: Instances of this class do not hold the data; they operate on views of the data.
*/
-private[tree] class VarianceAggregator()
+private[spark] class VarianceAggregator()
extends ImpurityAggregator(statsSize = 3) with Serializable {
/**
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala
deleted file mode 100644
index 0cad473782..0000000000
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala
+++ /dev/null
@@ -1,47 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.mllib.tree.model
-
-import org.apache.spark.mllib.tree.configuration.FeatureType._
-
-/**
- * Used for "binning" the feature values for faster best split calculation.
- *
- * For a continuous feature, the bin is determined by a low and a high split,
- * where an example with featureValue falls into the bin s.t.
- * lowSplit.threshold < featureValue <= highSplit.threshold.
- *
- * For ordered categorical features, there is a 1-1-1 correspondence between
- * bins, splits, and feature values. The bin is determined by category/feature value.
- * However, the bins are not necessarily ordered by feature value;
- * they are ordered using impurity.
- *
- * For unordered categorical features, there is a 1-1 correspondence between bins and splits,
- * where each bin/split corresponds to a subset of feature values (in highSplit.categories).
- * An unordered feature with k categories uses (1 << k - 1) - 1 bins, corresponding to all
- * partitionings of categories into 2 disjoint, non-empty sets.
- *
- * @param lowSplit signifying the lower threshold for the continuous feature to be
- * accepted in the bin
- * @param highSplit signifying the upper threshold for the continuous feature to be
- * accepted in the bin
- * @param featureType type of feature -- categorical or continuous
- * @param category categorical label value accepted in the bin for ordered features
- */
-private[tree]
-case class Bin(lowSplit: Split, highSplit: Split, featureType: FeatureType, category: Double)
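
The scaladoc above states that an unordered feature with k categories uses (1 << k - 1) - 1 bins, one per partition of the categories into two disjoint, non-empty sets. A throwaway sketch (not Spark code) that makes that count concrete:

object UnorderedBinCount {
  // 2^(k-1) - 1: the number of ways to split k categories into two disjoint,
  // non-empty sets. Note '-' binds tighter than '<<' in Scala, so the
  // (1 << k - 1) - 1 in the scaladoc means (1 << (k - 1)) - 1.
  def numUnorderedBins(k: Int): Int = (1 << (k - 1)) - 1

  def main(args: Array[String]): Unit = {
    // k = 3 categories {a, b, c}: {a}|{b,c}, {b}|{a,c}, {c}|{a,b} -> 3 bins
    println(numUnorderedBins(3)) // 3
    println(numUnorderedBins(4)) // 7
  }
}
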
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
index ea68ff64a8..a87f8a6cde 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
@@ -156,7 +156,7 @@ object DecisionTreeModel extends Loader[DecisionTreeModel] with Logging {
feature: Int,
threshold: Double,
featureType: Int,
- categories: Seq[Double]) { // TODO: Change to List once SPARK-3365 is fixed
+ categories: Seq[Double]) {
def toSplit: Split = {
new Split(feature, threshold, FeatureType(featureType), categories.toList)
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala
index c3b1d5cdd7..774170ff40 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala
@@ -67,42 +67,14 @@ object MLUtils {
path: String,
numFeatures: Int,
minPartitions: Int): RDD[LabeledPoint] = {
- val parsed = sc.textFile(path, minPartitions)
- .map(_.trim)
- .filter(line => !(line.isEmpty || line.startsWith("#")))
- .map { line =>
- val items = line.split(' ')
- val label = items.head.toDouble
- val (indices, values) = items.tail.filter(_.nonEmpty).map { item =>
- val indexAndValue = item.split(':')
- val index = indexAndValue(0).toInt - 1 // Convert 1-based indices to 0-based.
- val value = indexAndValue(1).toDouble
- (index, value)
- }.unzip
-
- // check if indices are one-based and in ascending order
- var previous = -1
- var i = 0
- val indicesLength = indices.length
- while (i < indicesLength) {
- val current = indices(i)
- require(current > previous, s"indices should be one-based and in ascending order;"
- + " found current=$current, previous=$previous; line=\"$line\"")
- previous = current
- i += 1
- }
-
- (label, indices.toArray, values.toArray)
- }
+ val parsed = parseLibSVMFile(sc, path, minPartitions)
// Determine number of features.
val d = if (numFeatures > 0) {
numFeatures
} else {
parsed.persist(StorageLevel.MEMORY_ONLY)
- parsed.map { case (label, indices, values) =>
- indices.lastOption.getOrElse(0)
- }.reduce(math.max) + 1
+ computeNumFeatures(parsed)
}
parsed.map { case (label, indices, values) =>
@@ -110,6 +82,47 @@ object MLUtils {
}
}
+ private[spark] def computeNumFeatures(rdd: RDD[(Double, Array[Int], Array[Double])]): Int = {
+ rdd.map { case (label, indices, values) =>
+ indices.lastOption.getOrElse(0)
+ }.reduce(math.max) + 1
+ }
+
+ private[spark] def parseLibSVMFile(
+ sc: SparkContext,
+ path: String,
+ minPartitions: Int): RDD[(Double, Array[Int], Array[Double])] = {
+ sc.textFile(path, minPartitions)
+ .map(_.trim)
+ .filter(line => !(line.isEmpty || line.startsWith("#")))
+ .map(parseLibSVMRecord)
+ }
+
+ private[spark] def parseLibSVMRecord(line: String): (Double, Array[Int], Array[Double]) = {
+ val items = line.split(' ')
+ val label = items.head.toDouble
+ val (indices, values) = items.tail.filter(_.nonEmpty).map { item =>
+ val indexAndValue = item.split(':')
+ val index = indexAndValue(0).toInt - 1 // Convert 1-based indices to 0-based.
+ val value = indexAndValue(1).toDouble
+ (index, value)
+ }.unzip
+
+ // check if indices are one-based and in ascending order
+ var previous = -1
+ var i = 0
+ val indicesLength = indices.length
+ while (i < indicesLength) {
+ val current = indices(i)
+        require(current > previous, "indices should be one-based and in ascending order;" +
+          s" found current=$current, previous=$previous; line=\"$line\"")
+ previous = current
+ i += 1
+ }
+
+ (label, indices.toArray, values.toArray)
+ }
+
/**
* Loads labeled data in the LIBSVM format into an RDD[LabeledPoint], with the default number of
* partitions.
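
The new parseLibSVMFile/parseLibSVMRecord helpers above factor the LIBSVM text parsing out of loadLibSVMFile: each record is a label followed by space-separated index:value pairs with 1-based, strictly ascending indices, which are converted to 0-based, and the feature count is then the largest 0-based index plus one (computeNumFeatures). A self-contained sketch of that parsing with no Spark dependencies; the object and helper names are illustrative only:

object LibSVMParseSketch {
  // Parse "label i1:v1 i2:v2 ..." into (label, 0-based indices, values),
  // mirroring parseLibSVMRecord above; indices must be strictly ascending.
  def parse(line: String): (Double, Array[Int], Array[Double]) = {
    val items = line.trim.split(' ')
    val label = items.head.toDouble
    val (indices, values) = items.tail.filter(_.nonEmpty).map { item =>
      val Array(index, value) = item.split(':')
      (index.toInt - 1, value.toDouble) // convert 1-based index to 0-based
    }.unzip
    indices.sliding(2).foreach {
      case Array(prev, current) =>
        require(current > prev, s"indices must be strictly ascending; line: $line")
      case _ => // zero or one feature: nothing to compare
    }
    (label, indices, values)
  }

  def main(args: Array[String]): Unit = {
    val (label, indices, values) = parse("1.0 1:0.5 3:2.0")
    println(label)                 // 1.0
    println(indices.mkString(",")) // 0,2
    println(values.mkString(","))  // 0.5,2.0
    println(indices.max + 1)       // 3, i.e. largest 0-based index + 1, as computeNumFeatures does
  }
}
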
diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaMultilayerPerceptronClassifierSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaMultilayerPerceptronClassifierSuite.java
index d499d363f1..bc955f3cf6 100644
--- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaMultilayerPerceptronClassifierSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaMultilayerPerceptronClassifierSuite.java
@@ -63,7 +63,7 @@ public class JavaMultilayerPerceptronClassifierSuite implements Serializable {
MultilayerPerceptronClassifier mlpc = new MultilayerPerceptronClassifier()
.setLayers(new int[] {2, 5, 2})
.setBlockSize(1)
- .setSeed(11L)
+ .setSeed(123L)
.setMaxIter(100);
MultilayerPerceptronClassificationModel model = mlpc.fit(dataFrame);
Dataset<Row> result = model.transform(dataFrame);
diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java
index 75061464e5..5aec52ac72 100644
--- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java
@@ -22,6 +22,7 @@ import java.util.HashMap;
import java.util.Map;
import org.junit.After;
+import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
@@ -80,6 +81,24 @@ public class JavaRandomForestClassifierSuite implements Serializable {
for (String featureSubsetStrategy: RandomForestClassifier.supportedFeatureSubsetStrategies()) {
rf.setFeatureSubsetStrategy(featureSubsetStrategy);
}
+ String realStrategies[] = {".1", ".10", "0.10", "0.1", "0.9", "1.0"};
+ for (String strategy: realStrategies) {
+ rf.setFeatureSubsetStrategy(strategy);
+ }
+ String integerStrategies[] = {"1", "10", "100", "1000", "10000"};
+ for (String strategy: integerStrategies) {
+ rf.setFeatureSubsetStrategy(strategy);
+ }
+ String invalidStrategies[] = {"-.1", "-.10", "-0.10", ".0", "0.0", "1.1", "0"};
+ for (String strategy: invalidStrategies) {
+ try {
+ rf.setFeatureSubsetStrategy(strategy);
+ Assert.fail("Expected exception to be thrown for invalid strategies");
+ } catch (Exception e) {
+ Assert.assertTrue(e instanceof IllegalArgumentException);
+ }
+ }
+
RandomForestClassificationModel model = rf.fit(dataFrame);
model.transform(dataFrame);
diff --git a/mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java b/mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java
index 65841182df..06f7fbb86e 100644
--- a/mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java
+++ b/mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java
@@ -89,7 +89,7 @@ public class JavaTestParams extends JavaParams {
myDoubleParam_ = new DoubleParam(this, "myDoubleParam", "this is a double param",
ParamValidators.inRange(0.0, 1.0));
List<String> validStrings = Arrays.asList("a", "b");
- myStringParam_ = new Param<String>(this, "myStringParam", "this is a string param",
+ myStringParam_ = new Param<>(this, "myStringParam", "this is a string param",
ParamValidators.inArray(validStrings));
myDoubleArrayParam_ =
new DoubleArrayParam(this, "myDoubleArrayParam", "this is a double param");
diff --git a/mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java b/mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java
index b6f793f6de..a8736669f7 100644
--- a/mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java
@@ -22,6 +22,7 @@ import java.util.HashMap;
import java.util.Map;
import org.junit.After;
+import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
@@ -80,6 +81,24 @@ public class JavaRandomForestRegressorSuite implements Serializable {
for (String featureSubsetStrategy: RandomForestRegressor.supportedFeatureSubsetStrategies()) {
rf.setFeatureSubsetStrategy(featureSubsetStrategy);
}
+ String realStrategies[] = {".1", ".10", "0.10", "0.1", "0.9", "1.0"};
+ for (String strategy: realStrategies) {
+ rf.setFeatureSubsetStrategy(strategy);
+ }
+ String integerStrategies[] = {"1", "10", "100", "1000", "10000"};
+ for (String strategy: integerStrategies) {
+ rf.setFeatureSubsetStrategy(strategy);
+ }
+ String invalidStrategies[] = {"-.1", "-.10", "-0.10", ".0", "0.0", "1.1", "0"};
+ for (String strategy: invalidStrategies) {
+ try {
+ rf.setFeatureSubsetStrategy(strategy);
+ Assert.fail("Expected exception to be thrown for invalid strategies");
+ } catch (Exception e) {
+ Assert.assertTrue(e instanceof IllegalArgumentException);
+ }
+ }
+
RandomForestRegressionModel model = rf.fit(dataFrame);
model.transform(dataFrame);
diff --git a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaStreamingLogisticRegressionSuite.java b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaStreamingLogisticRegressionSuite.java
index c9e5ee22f3..62c6d9b7e3 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaStreamingLogisticRegressionSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaStreamingLogisticRegressionSuite.java
@@ -66,8 +66,8 @@ public class JavaStreamingLogisticRegressionSuite implements Serializable {
JavaDStream<LabeledPoint> training =
attachTestInputStream(ssc, Arrays.asList(trainingBatch, trainingBatch), 2);
List<Tuple2<Integer, Vector>> testBatch = Arrays.asList(
- new Tuple2<Integer, Vector>(10, Vectors.dense(1.0)),
- new Tuple2<Integer, Vector>(11, Vectors.dense(0.0)));
+ new Tuple2<>(10, Vectors.dense(1.0)),
+ new Tuple2<>(11, Vectors.dense(0.0)));
JavaPairDStream<Integer, Vector> test = JavaPairDStream.fromJavaDStream(
attachTestInputStream(ssc, Arrays.asList(testBatch, testBatch), 2));
StreamingLogisticRegressionWithSGD slr = new StreamingLogisticRegressionWithSGD()
diff --git a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaStreamingKMeansSuite.java b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaStreamingKMeansSuite.java
index d644766d1e..62edbd3a29 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaStreamingKMeansSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaStreamingKMeansSuite.java
@@ -66,8 +66,8 @@ public class JavaStreamingKMeansSuite implements Serializable {
JavaDStream<Vector> training =
attachTestInputStream(ssc, Arrays.asList(trainingBatch, trainingBatch), 2);
List<Tuple2<Integer, Vector>> testBatch = Arrays.asList(
- new Tuple2<Integer, Vector>(10, Vectors.dense(1.0)),
- new Tuple2<Integer, Vector>(11, Vectors.dense(0.0)));
+ new Tuple2<>(10, Vectors.dense(1.0)),
+ new Tuple2<>(11, Vectors.dense(0.0)));
JavaPairDStream<Integer, Vector> test = JavaPairDStream.fromJavaDStream(
attachTestInputStream(ssc, Arrays.asList(testBatch, testBatch), 2));
StreamingKMeans skmeans = new StreamingKMeans()
diff --git a/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaPrefixSpanSuite.java b/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaPrefixSpanSuite.java
index 34daf5fbde..8a67793abc 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaPrefixSpanSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaPrefixSpanSuite.java
@@ -17,6 +17,7 @@
package org.apache.spark.mllib.fpm;
+import java.io.File;
import java.util.Arrays;
import java.util.List;
@@ -28,6 +29,7 @@ import org.junit.Test;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.fpm.PrefixSpan.FreqSequence;
+import org.apache.spark.util.Utils;
public class JavaPrefixSpanSuite {
private transient JavaSparkContext sc;
@@ -64,4 +66,39 @@ public class JavaPrefixSpanSuite {
long freq = freqSeq.freq();
}
}
+
+ @Test
+ public void runPrefixSpanSaveLoad() {
+ JavaRDD<List<List<Integer>>> sequences = sc.parallelize(Arrays.asList(
+ Arrays.asList(Arrays.asList(1, 2), Arrays.asList(3)),
+ Arrays.asList(Arrays.asList(1), Arrays.asList(3, 2), Arrays.asList(1, 2)),
+ Arrays.asList(Arrays.asList(1, 2), Arrays.asList(5)),
+ Arrays.asList(Arrays.asList(6))
+ ), 2);
+ PrefixSpan prefixSpan = new PrefixSpan()
+ .setMinSupport(0.5)
+ .setMaxPatternLength(5);
+ PrefixSpanModel<Integer> model = prefixSpan.run(sequences);
+
+ File tempDir = Utils.createTempDir(
+ System.getProperty("java.io.tmpdir"), "JavaPrefixSpanSuite");
+ String outputPath = tempDir.getPath();
+
+ try {
+ model.save(sc.sc(), outputPath);
+ PrefixSpanModel newModel = PrefixSpanModel.load(sc.sc(), outputPath);
+ JavaRDD<FreqSequence<Integer>> freqSeqs = newModel.freqSequences().toJavaRDD();
+ List<FreqSequence<Integer>> localFreqSeqs = freqSeqs.collect();
+ Assert.assertEquals(5, localFreqSeqs.size());
+ // Check that each frequent sequence could be materialized.
+ for (PrefixSpan.FreqSequence<Integer> freqSeq: localFreqSeqs) {
+ List<List<Integer>> seq = freqSeq.javaSequence();
+ long freq = freqSeq.freq();
+ }
+ } finally {
+ Utils.deleteRecursively(tempDir);
+ }
+  }
}
diff --git a/mllib/src/test/java/org/apache/spark/mllib/linalg/JavaVectorsSuite.java b/mllib/src/test/java/org/apache/spark/mllib/linalg/JavaVectorsSuite.java
index 77c8c6274f..4ba8e543a9 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/linalg/JavaVectorsSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/linalg/JavaVectorsSuite.java
@@ -37,8 +37,8 @@ public class JavaVectorsSuite implements Serializable {
public void sparseArrayConstruction() {
@SuppressWarnings("unchecked")
Vector v = Vectors.sparse(3, Arrays.asList(
- new Tuple2<Integer, Double>(0, 2.0),
- new Tuple2<Integer, Double>(2, 3.0)));
+ new Tuple2<>(0, 2.0),
+ new Tuple2<>(2, 3.0)));
assertArrayEquals(new double[]{2.0, 0.0, 3.0}, v.toArray(), 0.0);
}
}
diff --git a/mllib/src/test/java/org/apache/spark/mllib/regression/JavaStreamingLinearRegressionSuite.java b/mllib/src/test/java/org/apache/spark/mllib/regression/JavaStreamingLinearRegressionSuite.java
index dbf6488d41..ea0ccd7448 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/regression/JavaStreamingLinearRegressionSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/regression/JavaStreamingLinearRegressionSuite.java
@@ -65,8 +65,8 @@ public class JavaStreamingLinearRegressionSuite implements Serializable {
JavaDStream<LabeledPoint> training =
attachTestInputStream(ssc, Arrays.asList(trainingBatch, trainingBatch), 2);
List<Tuple2<Integer, Vector>> testBatch = Arrays.asList(
- new Tuple2<Integer, Vector>(10, Vectors.dense(1.0)),
- new Tuple2<Integer, Vector>(11, Vectors.dense(0.0)));
+ new Tuple2<>(10, Vectors.dense(1.0)),
+ new Tuple2<>(11, Vectors.dense(0.0)));
JavaPairDStream<Integer, Vector> test = JavaPairDStream.fromJavaDStream(
attachTestInputStream(ssc, Arrays.asList(testBatch, testBatch), 2));
StreamingLinearRegressionWithSGD slr = new StreamingLinearRegressionWithSGD()
diff --git a/mllib/src/test/resources/log4j.properties b/mllib/src/test/resources/log4j.properties
index 75e3b53a09..fd51f8faf5 100644
--- a/mllib/src/test/resources/log4j.properties
+++ b/mllib/src/test/resources/log4j.properties
@@ -24,5 +24,5 @@ log4j.appender.file.layout=org.apache.log4j.PatternLayout
log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n
# Ignore messages below warning level from Jetty, because it's a bit verbose
-log4j.logger.org.spark-project.jetty=WARN
+log4j.logger.org.spark_project.jetty=WARN
diff --git a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala
index f3321fb5a1..a8c4ac6d05 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala
@@ -31,7 +31,7 @@ import org.apache.spark.ml.param.{IntParam, ParamMap}
import org.apache.spark.ml.util._
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.util.MLlibTestSparkContext
-import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.types.StructType
class PipelineSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
@@ -51,6 +51,12 @@ class PipelineSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
val dataset3 = mock[DataFrame]
val dataset4 = mock[DataFrame]
+ when(dataset0.toDF).thenReturn(dataset0)
+ when(dataset1.toDF).thenReturn(dataset1)
+ when(dataset2.toDF).thenReturn(dataset2)
+ when(dataset3.toDF).thenReturn(dataset3)
+ when(dataset4.toDF).thenReturn(dataset4)
+
when(estimator0.copy(any[ParamMap])).thenReturn(estimator0)
when(model0.copy(any[ParamMap])).thenReturn(model0)
when(transformer1.copy(any[ParamMap])).thenReturn(transformer1)
@@ -213,7 +219,7 @@ class WritableStage(override val uid: String) extends Transformer with MLWritabl
override def write: MLWriter = new DefaultParamsWriter(this)
- override def transform(dataset: DataFrame): DataFrame = dataset
+ override def transform(dataset: Dataset[_]): DataFrame = dataset.toDF
override def transformSchema(schema: StructType): StructType = schema
}
@@ -234,7 +240,7 @@ class UnWritableStage(override val uid: String) extends Transformer {
override def copy(extra: ParamMap): UnWritableStage = defaultCopy(extra)
- override def transform(dataset: DataFrame): DataFrame = dataset
+ override def transform(dataset: Dataset[_]): DataFrame = dataset.toDF
override def transformSchema(schema: StructType): StructType = schema
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/ann/ANNSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/ann/ANNSuite.scala
index 1292e57d7c..dc91fc5f9e 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/ann/ANNSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/ann/ANNSuite.scala
@@ -42,7 +42,7 @@ class ANNSuite extends SparkFunSuite with MLlibTestSparkContext {
val dataSample = rddData.first()
val layerSizes = dataSample._1.size +: hiddenLayersTopology :+ dataSample._2.size
val topology = FeedForwardTopology.multiLayerPerceptron(layerSizes, false)
- val initialWeights = FeedForwardModel(topology, 23124).weights()
+ val initialWeights = FeedForwardModel(topology, 23124).weights
val trainer = new FeedForwardTrainer(topology, 2, 1)
trainer.setWeights(initialWeights)
trainer.LBFGSOptimizer.setNumIterations(20)
@@ -76,10 +76,11 @@ class ANNSuite extends SparkFunSuite with MLlibTestSparkContext {
val dataSample = rddData.first()
val layerSizes = dataSample._1.size +: hiddenLayersTopology :+ dataSample._2.size
val topology = FeedForwardTopology.multiLayerPerceptron(layerSizes, false)
- val initialWeights = FeedForwardModel(topology, 23124).weights()
+ val initialWeights = FeedForwardModel(topology, 23124).weights
val trainer = new FeedForwardTrainer(topology, 2, 2)
- trainer.SGDOptimizer.setNumIterations(2000)
- trainer.setWeights(initialWeights)
+ // TODO: add a test for SGD
+ trainer.LBFGSOptimizer.setConvergenceTol(1e-4).setNumIterations(20)
+ trainer.setWeights(initialWeights).setStackSize(1)
val model = trainer.train(rddData)
val predictionAndLabels = rddData.map { case (input, label) =>
(model.predict(input), label)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/ann/GradientSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/ann/GradientSuite.scala
new file mode 100644
index 0000000000..04cc426c40
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/ann/GradientSuite.scala
@@ -0,0 +1,76 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.ann
+
+import breeze.linalg.{DenseMatrix => BDM}
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.mllib.linalg.Vectors
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+
+class GradientSuite extends SparkFunSuite with MLlibTestSparkContext {
+
+ test("Gradient computation against numerical differentiation") {
+ val input = new BDM[Double](3, 1, Array(1.0, 1.0, 1.0))
+ // output must contain zeros and one 1 for SoftMax
+ val target = new BDM[Double](2, 1, Array(0.0, 1.0))
+ val topology = FeedForwardTopology.multiLayerPerceptron(Array(3, 4, 2), softmaxOnTop = false)
+ val layersWithErrors = Seq(
+ new SigmoidLayerWithSquaredError(),
+ new SoftmaxLayerWithCrossEntropyLoss()
+ )
+ // check all layers that provide loss computation
+ // 1) compute loss and gradient given the model and initial weights
+ // 2) modify weights with small number epsilon (per dimension i)
+ // 3) compute new loss
+ // 4) ((newLoss - loss) / epsilon) should be close to the i-th component of the gradient
+ for (layerWithError <- layersWithErrors) {
+ topology.layers(topology.layers.length - 1) = layerWithError
+ val model = topology.model(seed = 12L)
+ val weights = model.weights.toArray
+ val numWeights = weights.size
+ val gradient = Vectors.dense(Array.fill[Double](numWeights)(0.0))
+ val loss = model.computeGradient(input, target, gradient, 1)
+ val eps = 1e-4
+ var i = 0
+ val tol = 1e-4
+ while (i < numWeights) {
+ val originalValue = weights(i)
+ weights(i) += eps
+ val newModel = topology.model(Vectors.dense(weights))
+ val newLoss = computeLoss(input, target, newModel)
+ val derivativeEstimate = (newLoss - loss) / eps
+ assert(math.abs(gradient(i) - derivativeEstimate) < tol, "Layer failed gradient check: " +
+ layerWithError.getClass)
+ weights(i) = originalValue
+ i += 1
+ }
+ }
+ }
+
+ private def computeLoss(input: BDM[Double], target: BDM[Double], model: TopologyModel): Double = {
+ val outputs = model.forward(input)
+ model.layerModels.last match {
+ case layerWithLoss: LossFunction =>
+ layerWithLoss.loss(outputs.last, target, new BDM[Double](target.rows, target.cols))
+ case _ =>
+ throw new UnsupportedOperationException("Top layer is required to have loss." +
+ " Failed layer:" + model.layerModels.last.getClass)
+ }
+ }
+}
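
The comments in the new GradientSuite above describe a standard finite-difference gradient check: perturb one weight by a small eps, recompute the loss, and require that (newLoss - loss) / eps is close to the corresponding analytic gradient component. Detached from the ANN layers, the technique looks roughly like this, with a hypothetical quadratic loss standing in for model.computeGradient:

object GradientCheckSketch {
  // Hypothetical analytic loss/gradient pair: loss(w) = 0.5 * ||w||^2, whose
  // gradient is simply w. This is only a stand-in for the real model.
  def loss(w: Array[Double]): Double = 0.5 * w.map(x => x * x).sum
  def analyticGradient(w: Array[Double]): Array[Double] = w.clone()

  def main(args: Array[String]): Unit = {
    val weights = Array(0.3, -1.2, 2.0)
    val gradient = analyticGradient(weights)
    val baseLoss = loss(weights)
    val eps = 1e-4
    val tol = 1e-4
    var i = 0
    while (i < weights.length) {
      val originalValue = weights(i)
      weights(i) += eps                          // perturb one weight
      val derivativeEstimate = (loss(weights) - baseLoss) / eps
      assert(math.abs(gradient(i) - derivativeEstimate) < tol,
        s"gradient check failed for weight $i")
      weights(i) = originalValue                 // restore before the next weight
      i += 1
    }
    println("gradient check passed")
  }
}
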
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
index 2b07524815..fe839e15e9 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
@@ -27,8 +27,7 @@ import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree, DecisionTreeSuite => OldDecisionTreeSuite}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.DataFrame
-import org.apache.spark.sql.Row
+import org.apache.spark.sql.{DataFrame, Row}
class DecisionTreeClassifierSuite
extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
@@ -176,7 +175,7 @@ class DecisionTreeClassifierSuite
}
test("Multiclass classification tree with 10-ary (ordered) categorical features," +
- " with just enough bins") {
+ " with just enough bins") {
val rdd = categoricalDataPointsForMulticlassForOrderedFeaturesRDD
val dt = new DecisionTreeClassifier()
.setImpurity("Gini")
@@ -273,7 +272,7 @@ class DecisionTreeClassifierSuite
))
val df = TreeTests.setMetadata(data, Map(0 -> 1), 2)
val dt = new DecisionTreeClassifier().setMaxDepth(3)
- val model = dt.fit(df)
+ dt.fit(df)
}
test("Use soft prediction for binary classification with ordered categorical features") {
@@ -335,6 +334,14 @@ class DecisionTreeClassifierSuite
assert(importances.toArray.forall(_ >= 0.0))
}
+ test("should support all NumericType labels and not support other types") {
+ val dt = new DecisionTreeClassifier().setMaxDepth(1)
+ MLTestingUtils.checkNumericTypes[DecisionTreeClassificationModel, DecisionTreeClassifier](
+ dt, isClassification = true, sqlContext) { (expected, actual) =>
+ TreeTests.checkEqual(expected, actual)
+ }
+ }
+
/////////////////////////////////////////////////////////////////////////////
// Tests of model save/load
/////////////////////////////////////////////////////////////////////////////
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
index f3680ed044..7e6aec6b1b 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
@@ -22,7 +22,7 @@ import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.ml.regression.DecisionTreeRegressionModel
import org.apache.spark.ml.tree.LeafNode
import org.apache.spark.ml.tree.impl.TreeTests
-import org.apache.spark.ml.util.MLTestingUtils
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.{EnsembleTestHelper, GradientBoostedTrees => OldGBT}
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
@@ -31,11 +31,11 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.DataFrame
import org.apache.spark.util.Utils
-
/**
* Test suite for [[GBTClassifier]].
*/
-class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext {
+class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext
+ with DefaultReadWriteTest {
import GBTClassifierSuite.compareAPIs
@@ -102,6 +102,14 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext {
Utils.deleteRecursively(tempDir)
}
+ test("should support all NumericType labels and not support other types") {
+ val gbt = new GBTClassifier().setMaxDepth(1)
+ MLTestingUtils.checkNumericTypes[GBTClassificationModel, GBTClassifier](
+ gbt, isClassification = true, sqlContext) { (expected, actual) =>
+ TreeTests.checkEqual(expected, actual)
+ }
+ }
+
// TODO: Reinstate test once runWithValidation is implemented SPARK-7132
/*
test("runWithValidation stops early and performs better on a validation dataset") {
@@ -121,30 +129,51 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext {
*/
/////////////////////////////////////////////////////////////////////////////
+ // Tests of feature importance
+ /////////////////////////////////////////////////////////////////////////////
+ test("Feature importance with toy data") {
+ val numClasses = 2
+ val gbt = new GBTClassifier()
+ .setImpurity("Gini")
+ .setMaxDepth(3)
+ .setMaxIter(5)
+ .setSubsamplingRate(1.0)
+ .setStepSize(0.5)
+ .setSeed(123)
+
+ // In this data, feature 1 is very important.
+ val data: RDD[LabeledPoint] = TreeTests.featureImportanceData(sc)
+ val categoricalFeatures = Map.empty[Int, Int]
+ val df: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses)
+
+ val importances = gbt.fit(df).featureImportances
+ val mostImportantFeature = importances.argmax
+ assert(mostImportantFeature === 1)
+ assert(importances.toArray.sum === 1.0)
+ assert(importances.toArray.forall(_ >= 0.0))
+ }
+
+ /////////////////////////////////////////////////////////////////////////////
// Tests of model save/load
/////////////////////////////////////////////////////////////////////////////
- // TODO: Reinstate test once save/load are implemented SPARK-6725
- /*
test("model save/load") {
- val tempDir = Utils.createTempDir()
- val path = tempDir.toURI.toString
-
- val trees = Range(0, 3).map(_ => OldDecisionTreeSuite.createModel(OldAlgo.Regression)).toArray
- val treeWeights = Array(0.1, 0.3, 1.1)
- val oldModel = new OldGBTModel(OldAlgo.Classification, trees, treeWeights)
- val newModel = GBTClassificationModel.fromOld(oldModel)
-
- // Save model, load it back, and compare.
- try {
- newModel.save(sc, path)
- val sameNewModel = GBTClassificationModel.load(sc, path)
- TreeTests.checkEqual(newModel, sameNewModel)
- } finally {
- Utils.deleteRecursively(tempDir)
+ def checkModelData(
+ model: GBTClassificationModel,
+ model2: GBTClassificationModel): Unit = {
+ TreeTests.checkEqual(model, model2)
+ assert(model.numFeatures === model2.numFeatures)
}
+
+ val gbt = new GBTClassifier()
+ val rdd = TreeTests.getTreeReadWriteData(sc)
+
+ val allParamSettings = TreeTests.allParamSettings ++ Map("lossType" -> "logistic")
+
+ val continuousData: DataFrame =
+ TreeTests.setMetadata(rdd, Map.empty[Int, Int], numClasses = 2)
+ testEstimatorAndModelReadWrite(gbt, continuousData, allParamSettings, checkModelData)
}
- */
}
private object GBTClassifierSuite extends SparkFunSuite {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
index afeeaf7fb5..48db428130 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
@@ -29,13 +29,13 @@ import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
-import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.{DataFrame, Dataset, Row}
import org.apache.spark.sql.functions.lit
class LogisticRegressionSuite
extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
- @transient var dataset: DataFrame = _
+ @transient var dataset: Dataset[_] = _
@transient var binaryDataset: DataFrame = _
private val eps: Double = 1e-5
@@ -103,7 +103,7 @@ class LogisticRegressionSuite
assert(model.hasSummary)
// Validate that we re-insert a probability column for evaluation
val fieldNames = model.summary.predictions.schema.fieldNames
- assert((dataset.schema.fieldNames.toSet).subsetOf(
+ assert(dataset.schema.fieldNames.toSet.subsetOf(
fieldNames.toSet))
assert(fieldNames.exists(s => s.startsWith("probability_")))
}
@@ -934,6 +934,15 @@ class LogisticRegressionSuite
testEstimatorAndModelReadWrite(lr, dataset, LogisticRegressionSuite.allParamSettings,
checkModelData)
}
+
+ test("should support all NumericType labels and not support other types") {
+ val lr = new LogisticRegression().setMaxIter(1)
+ MLTestingUtils.checkNumericTypes[LogisticRegressionModel, LogisticRegression](
+ lr, isClassification = true, sqlContext) { (expected, actual) =>
+ assert(expected.intercept === actual.intercept)
+ assert(expected.coefficients.toArray === actual.coefficients.toArray)
+ }
+ }
}
object LogisticRegressionSuite {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala
index 5df8e6a847..80547fad6a 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala
@@ -19,18 +19,19 @@ package org.apache.spark.ml.classification
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.util.DefaultReadWriteTest
+import org.apache.spark.ml.util.MLTestingUtils
import org.apache.spark.mllib.classification.LogisticRegressionSuite._
import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS
import org.apache.spark.mllib.evaluation.MulticlassMetrics
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
-import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.{DataFrame, Dataset, Row}
class MultilayerPerceptronClassifierSuite
extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
- @transient var dataset: DataFrame = _
+ @transient var dataset: Dataset[_] = _
override def beforeAll(): Unit = {
super.beforeAll()
@@ -43,12 +44,29 @@ class MultilayerPerceptronClassifierSuite
).toDF("features", "label")
}
+ test("Input Validation") {
+ val mlpc = new MultilayerPerceptronClassifier()
+ intercept[IllegalArgumentException] {
+ mlpc.setLayers(Array[Int]())
+ }
+ intercept[IllegalArgumentException] {
+ mlpc.setLayers(Array[Int](1))
+ }
+ intercept[IllegalArgumentException] {
+ mlpc.setLayers(Array[Int](0, 1))
+ }
+ intercept[IllegalArgumentException] {
+ mlpc.setLayers(Array[Int](1, 0))
+ }
+ mlpc.setLayers(Array[Int](1, 1))
+ }
+
test("XOR function learning as binary classification problem with two outputs.") {
val layers = Array[Int](2, 5, 2)
val trainer = new MultilayerPerceptronClassifier()
.setLayers(layers)
.setBlockSize(1)
- .setSeed(11L)
+ .setSeed(123L)
.setMaxIter(100)
val model = trainer.fit(dataset)
val result = model.transform(dataset)
@@ -58,7 +76,29 @@ class MultilayerPerceptronClassifierSuite
}
}
- // TODO: implement a more rigorous test
+ test("Test setWeights by training restart") {
+ val dataFrame = sqlContext.createDataFrame(Seq(
+ (Vectors.dense(0.0, 0.0), 0.0),
+ (Vectors.dense(0.0, 1.0), 1.0),
+ (Vectors.dense(1.0, 0.0), 1.0),
+ (Vectors.dense(1.0, 1.0), 0.0))
+ ).toDF("features", "label")
+ val layers = Array[Int](2, 5, 2)
+ val trainer = new MultilayerPerceptronClassifier()
+ .setLayers(layers)
+ .setBlockSize(1)
+ .setSeed(12L)
+ .setMaxIter(1)
+ .setTol(1e-6)
+ val initialWeights = trainer.fit(dataFrame).weights
+ trainer.setWeights(initialWeights.copy)
+ val weights1 = trainer.fit(dataFrame).weights
+ trainer.setWeights(initialWeights.copy)
+ val weights2 = trainer.fit(dataFrame).weights
+ assert(weights1 ~== weights2 absTol 10e-5,
+ "Training should produce the same weights given equal initial weights and number of steps")
+ }
+
test("3 class classification with 2 hidden layers") {
val nPoints = 1000
@@ -123,4 +163,15 @@ class MultilayerPerceptronClassifierSuite
assert(newMlpModel.layers === mlpModel.layers)
assert(newMlpModel.weights === mlpModel.weights)
}
+
+ test("should support all NumericType labels and not support other types") {
+ val layers = Array(3, 2)
+ val mpc = new MultilayerPerceptronClassifier().setLayers(layers).setMaxIter(1)
+ MLTestingUtils.checkNumericTypes[
+ MultilayerPerceptronClassificationModel, MultilayerPerceptronClassifier](
+ mpc, isClassification = true, sqlContext) { (expected, actual) =>
+ assert(expected.layers === actual.layers)
+ assert(expected.weights === actual.weights)
+ }
+ }
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala
index 082a6bcd21..80a46fc70c 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala
@@ -21,17 +21,17 @@ import breeze.linalg.{Vector => BV}
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.param.ParamsSuite
-import org.apache.spark.ml.util.DefaultReadWriteTest
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
import org.apache.spark.mllib.classification.NaiveBayes.{Bernoulli, Multinomial}
import org.apache.spark.mllib.classification.NaiveBayesSuite._
import org.apache.spark.mllib.linalg._
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
-import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.{DataFrame, Dataset, Row}
class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
- @transient var dataset: DataFrame = _
+ @transient var dataset: Dataset[_] = _
override def beforeAll(): Unit = {
super.beforeAll()
@@ -86,7 +86,7 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
model: NaiveBayesModel,
modelType: String): Unit = {
featureAndProbabilities.collect().foreach {
- case Row(features: Vector, probability: Vector) => {
+ case Row(features: Vector, probability: Vector) =>
assert(probability.toArray.sum ~== 1.0 relTol 1.0e-10)
val expected = modelType match {
case Multinomial =>
@@ -97,7 +97,6 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
throw new UnknownError(s"Invalid modelType: $modelType.")
}
assert(probability ~== expected relTol 1.0e-10)
- }
}
}
@@ -185,6 +184,15 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
val nb = new NaiveBayes()
testEstimatorAndModelReadWrite(nb, dataset, NaiveBayesSuite.allParamSettings, checkModelData)
}
+
+ test("should support all NumericType labels and not support other types") {
+ val nb = new NaiveBayes()
+ MLTestingUtils.checkNumericTypes[NaiveBayesModel, NaiveBayes](
+ nb, isClassification = true, sqlContext) { (expected, actual) =>
+ assert(expected.pi === actual.pi)
+ assert(expected.theta === actual.theta)
+ }
+ }
}
object NaiveBayesSuite {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
index 2ae74a2090..f3e8fd11b2 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
@@ -21,7 +21,7 @@ import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.attribute.NominalAttribute
import org.apache.spark.ml.feature.StringIndexer
import org.apache.spark.ml.param.{ParamMap, ParamsSuite}
-import org.apache.spark.ml.util.{MetadataUtils, MLTestingUtils}
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MetadataUtils, MLTestingUtils}
import org.apache.spark.mllib.classification.LogisticRegressionSuite._
import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS
import org.apache.spark.mllib.evaluation.MulticlassMetrics
@@ -30,12 +30,12 @@ import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.types.Metadata
-class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext {
+class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
- @transient var dataset: DataFrame = _
+ @transient var dataset: Dataset[_] = _
@transient var rdd: RDD[LabeledPoint] = _
override def beforeAll(): Unit = {
@@ -74,7 +74,7 @@ class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext {
// copied model must have the same parent.
MLTestingUtils.checkCopy(ovaModel)
- assert(ovaModel.models.size === numClasses)
+ assert(ovaModel.models.length === numClasses)
val transformedDataset = ovaModel.transform(dataset)
// check for label metadata in prediction col
@@ -160,6 +160,84 @@ class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext {
require(m.getThreshold === 0.1, "copy should handle extra model params")
}
}
+
+ test("read/write: OneVsRest") {
+ val lr = new LogisticRegression().setMaxIter(10).setRegParam(0.01)
+
+ val ova = new OneVsRest()
+ .setClassifier(lr)
+ .setLabelCol("myLabel")
+ .setFeaturesCol("myFeature")
+ .setPredictionCol("myPrediction")
+
+ val ova2 = testDefaultReadWrite(ova, testParams = false)
+ assert(ova.uid === ova2.uid)
+ assert(ova.getFeaturesCol === ova2.getFeaturesCol)
+ assert(ova.getLabelCol === ova2.getLabelCol)
+ assert(ova.getPredictionCol === ova2.getPredictionCol)
+
+ ova2.getClassifier match {
+ case lr2: LogisticRegression =>
+ assert(lr.uid === lr2.uid)
+ assert(lr.getMaxIter === lr2.getMaxIter)
+ assert(lr.getRegParam === lr2.getRegParam)
+ case other =>
+ throw new AssertionError(s"Loaded OneVsRest expected classifier of type" +
+ s" LogisticRegression but found ${other.getClass.getName}")
+ }
+ }
+
+ test("read/write: OneVsRestModel") {
+ def checkModelData(model: OneVsRestModel, model2: OneVsRestModel): Unit = {
+ assert(model.uid === model2.uid)
+ assert(model.getFeaturesCol === model2.getFeaturesCol)
+ assert(model.getLabelCol === model2.getLabelCol)
+ assert(model.getPredictionCol === model2.getPredictionCol)
+
+ val classifier = model.getClassifier.asInstanceOf[LogisticRegression]
+
+ model2.getClassifier match {
+ case lr2: LogisticRegression =>
+ assert(classifier.uid === lr2.uid)
+ assert(classifier.getMaxIter === lr2.getMaxIter)
+ assert(classifier.getRegParam === lr2.getRegParam)
+ case other =>
+ throw new AssertionError(s"Loaded OneVsRestModel expected classifier of type" +
+ s" LogisticRegression but found ${other.getClass.getName}")
+ }
+
+ assert(model.labelMetadata === model2.labelMetadata)
+ model.models.zip(model2.models).foreach {
+ case (lrModel1: LogisticRegressionModel, lrModel2: LogisticRegressionModel) =>
+ assert(lrModel1.uid === lrModel2.uid)
+ assert(lrModel1.coefficients === lrModel2.coefficients)
+ assert(lrModel1.intercept === lrModel2.intercept)
+ case other =>
+ throw new AssertionError(s"Loaded OneVsRestModel expected model of type" +
+ s" LogisticRegressionModel but found ${other.getClass.getName}")
+ }
+ }
+
+ val lr = new LogisticRegression().setMaxIter(10).setRegParam(0.01)
+ val ova = new OneVsRest().setClassifier(lr)
+ val ovaModel = ova.fit(dataset)
+ val newOvaModel = testDefaultReadWrite(ovaModel, testParams = false)
+ checkModelData(ovaModel, newOvaModel)
+ }
+
+ test("should support all NumericType labels and not support other types") {
+ val ovr = new OneVsRest().setClassifier(new LogisticRegression().setMaxIter(1))
+ MLTestingUtils.checkNumericTypes[OneVsRestModel, OneVsRest](
+ ovr, isClassification = true, sqlContext) { (expected, actual) =>
+ val expectedModels = expected.models.map(m => m.asInstanceOf[LogisticRegressionModel])
+ val actualModels = actual.models.map(m => m.asInstanceOf[LogisticRegressionModel])
+ assert(expectedModels.length === actualModels.length)
+ expectedModels.zip(actualModels).foreach { case (e, a) =>
+ assert(e.intercept === a.intercept)
+ assert(e.coefficients.toArray === a.coefficients.toArray)
+ }
+ }
+ }
}
private class MockLogisticRegression(uid: String) extends LogisticRegression(uid) {
@@ -168,7 +246,7 @@ private class MockLogisticRegression(uid: String) extends LogisticRegression(uid
setMaxIter(1)
- override protected[spark] def train(dataset: DataFrame): LogisticRegressionModel = {
+ override protected[spark] def train(dataset: Dataset[_]): LogisticRegressionModel = {
val labelSchema = dataset.schema($(labelCol))
// check for label attribute propagation.
assert(MetadataUtils.getNumClasses(labelSchema).forall(_ == 2))
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
index b896099e31..aaaa429103 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
@@ -21,7 +21,7 @@ import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.ml.tree.LeafNode
import org.apache.spark.ml.tree.impl.TreeTests
-import org.apache.spark.ml.util.MLTestingUtils
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.{EnsembleTestHelper, RandomForest => OldRandomForest}
@@ -34,7 +34,8 @@ import org.apache.spark.sql.{DataFrame, Row}
/**
* Test suite for [[RandomForestClassifier]].
*/
-class RandomForestClassifierSuite extends SparkFunSuite with MLlibTestSparkContext {
+class RandomForestClassifierSuite
+ extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
import RandomForestClassifierSuite.compareAPIs
@@ -178,31 +179,36 @@ class RandomForestClassifierSuite extends SparkFunSuite with MLlibTestSparkConte
assert(importances.toArray.forall(_ >= 0.0))
}
+ test("should support all NumericType labels and not support other types") {
+ val rf = new RandomForestClassifier().setMaxDepth(1)
+ MLTestingUtils.checkNumericTypes[RandomForestClassificationModel, RandomForestClassifier](
+ rf, isClassification = true, sqlContext) { (expected, actual) =>
+ TreeTests.checkEqual(expected, actual)
+ }
+ }
+
/////////////////////////////////////////////////////////////////////////////
// Tests of model save/load
/////////////////////////////////////////////////////////////////////////////
- // TODO: Reinstate test once save/load are implemented SPARK-6725
- /*
- test("model save/load") {
- val tempDir = Utils.createTempDir()
- val path = tempDir.toURI.toString
-
- val trees =
- Range(0, 3).map(_ => OldDecisionTreeSuite.createModel(OldAlgo.Classification)).toArray
- val oldModel = new OldRandomForestModel(OldAlgo.Classification, trees)
- val newModel = RandomForestClassificationModel.fromOld(oldModel)
-
- // Save model, load it back, and compare.
- try {
- newModel.save(sc, path)
- val sameNewModel = RandomForestClassificationModel.load(sc, path)
- TreeTests.checkEqual(newModel, sameNewModel)
- } finally {
- Utils.deleteRecursively(tempDir)
+ test("read/write") {
+ def checkModelData(
+ model: RandomForestClassificationModel,
+ model2: RandomForestClassificationModel): Unit = {
+ TreeTests.checkEqual(model, model2)
+ assert(model.numFeatures === model2.numFeatures)
+ assert(model.numClasses === model2.numClasses)
}
+
+ val rf = new RandomForestClassifier().setNumTrees(2)
+ val rdd = TreeTests.getTreeReadWriteData(sc)
+
+ val allParamSettings = TreeTests.allParamSettings ++ Map("impurity" -> "entropy")
+
+ val continuousData: DataFrame =
+ TreeTests.setMetadata(rdd, Map.empty[Int, Int], numClasses = 2)
+ testEstimatorAndModelReadWrite(rf, continuousData, allParamSettings, checkModelData)
}
- */
}
private object RandomForestClassifierSuite extends SparkFunSuite {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala
index b719a8c7e7..e641d79c17 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala
@@ -18,13 +18,15 @@
package org.apache.spark.ml.clustering
import org.apache.spark.SparkFunSuite
+import org.apache.spark.ml.util.DefaultReadWriteTest
import org.apache.spark.mllib.util.MLlibTestSparkContext
-import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.{DataFrame, Dataset}
-class BisectingKMeansSuite extends SparkFunSuite with MLlibTestSparkContext {
+class BisectingKMeansSuite
+ extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
final val k = 5
- @transient var dataset: DataFrame = _
+ @transient var dataset: Dataset[_] = _
override def beforeAll(): Unit = {
super.beforeAll()
@@ -84,4 +86,22 @@ class BisectingKMeansSuite extends SparkFunSuite with MLlibTestSparkContext {
assert(model.computeCost(dataset) < 0.1)
assert(model.hasParent)
}
+
+ test("read/write") {
+ def checkModelData(model: BisectingKMeansModel, model2: BisectingKMeansModel): Unit = {
+ assert(model.clusterCenters === model2.clusterCenters)
+ }
+ val bisectingKMeans = new BisectingKMeans()
+ testEstimatorAndModelReadWrite(
+ bisectingKMeans, dataset, BisectingKMeansSuite.allParamSettings, checkModelData)
+ }
+}
+
+object BisectingKMeansSuite {
+ val allParamSettings: Map[String, Any] = Map(
+ "k" -> 3,
+ "maxIter" -> 2,
+ "seed" -> -1L,
+ "minDivisibleClusterSize" -> 2.0
+ )
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala
new file mode 100644
index 0000000000..1a274aea29
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala
@@ -0,0 +1,133 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.clustering
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.ml.util.DefaultReadWriteTest
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.sql.{DataFrame, Dataset}
+
+
+class GaussianMixtureSuite extends SparkFunSuite with MLlibTestSparkContext
+ with DefaultReadWriteTest {
+
+ final val k = 5
+ @transient var dataset: Dataset[_] = _
+
+ override def beforeAll(): Unit = {
+ super.beforeAll()
+
+ dataset = KMeansSuite.generateKMeansData(sqlContext, 50, 3, k)
+ }
+
+ test("default parameters") {
+ val gm = new GaussianMixture()
+
+ assert(gm.getK === 2)
+ assert(gm.getFeaturesCol === "features")
+ assert(gm.getPredictionCol === "prediction")
+ assert(gm.getMaxIter === 100)
+ assert(gm.getTol === 0.01)
+ }
+
+ test("set parameters") {
+ val gm = new GaussianMixture()
+ .setK(9)
+ .setFeaturesCol("test_feature")
+ .setPredictionCol("test_prediction")
+ .setProbabilityCol("test_probability")
+ .setMaxIter(33)
+ .setSeed(123)
+ .setTol(1e-3)
+
+ assert(gm.getK === 9)
+ assert(gm.getFeaturesCol === "test_feature")
+ assert(gm.getPredictionCol === "test_prediction")
+ assert(gm.getProbabilityCol === "test_probability")
+ assert(gm.getMaxIter === 33)
+ assert(gm.getSeed === 123)
+ assert(gm.getTol === 1e-3)
+ }
+
+ test("parameters validation") {
+ intercept[IllegalArgumentException] {
+ new GaussianMixture().setK(1)
+ }
+ }
+
+ test("fit, transform, and summary") {
+ val predictionColName = "gm_prediction"
+ val probabilityColName = "gm_probability"
+ val gm = new GaussianMixture().setK(k).setMaxIter(2).setPredictionCol(predictionColName)
+ .setProbabilityCol(probabilityColName).setSeed(1)
+ val model = gm.fit(dataset)
+ assert(model.hasParent)
+ assert(model.weights.length === k)
+ assert(model.gaussians.length === k)
+
+ val transformed = model.transform(dataset)
+ val expectedColumns = Array("features", predictionColName, probabilityColName)
+ expectedColumns.foreach { column =>
+ assert(transformed.columns.contains(column))
+ }
+
+ // Check validity of model summary
+ val numRows = dataset.count()
+ assert(model.hasSummary)
+ val summary: GaussianMixtureSummary = model.summary
+ assert(summary.predictionCol === predictionColName)
+ assert(summary.probabilityCol === probabilityColName)
+ assert(summary.featuresCol === "features")
+ assert(summary.predictions.count() === numRows)
+ for (c <- Array(predictionColName, probabilityColName, "features")) {
+ assert(summary.predictions.columns.contains(c))
+ }
+ assert(summary.cluster.columns === Array(predictionColName))
+ assert(summary.probability.columns === Array(probabilityColName))
+ val clusterSizes = summary.clusterSizes
+ assert(clusterSizes.length === k)
+ assert(clusterSizes.sum === numRows)
+ assert(clusterSizes.forall(_ >= 0))
+ }
+
+ test("read/write") {
+ def checkModelData(model: GaussianMixtureModel, model2: GaussianMixtureModel): Unit = {
+ assert(model.weights === model2.weights)
+ assert(model.gaussians.map(_.mu) === model2.gaussians.map(_.mu))
+ assert(model.gaussians.map(_.sigma) === model2.gaussians.map(_.sigma))
+ }
+ val gm = new GaussianMixture()
+ testEstimatorAndModelReadWrite(gm, dataset,
+ GaussianMixtureSuite.allParamSettings, checkModelData)
+ }
+}
+
+object GaussianMixtureSuite {
+ /**
+ * Mapping from all Params to valid settings which differ from the defaults.
+ * This is useful for tests which need to exercise all Params, such as save/load.
+ * This excludes input columns to simplify some tests.
+ */
+ val allParamSettings: Map[String, Any] = Map(
+ "predictionCol" -> "myPrediction",
+ "probabilityCol" -> "myProbability",
+ "k" -> 3,
+ "maxIter" -> 2,
+ "tol" -> 0.01
+ )
+}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala
index c684bc11cc..2ca386e422 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala
@@ -22,14 +22,14 @@ import org.apache.spark.ml.util.DefaultReadWriteTest
import org.apache.spark.mllib.clustering.{KMeans => MLlibKMeans}
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.util.MLlibTestSparkContext
-import org.apache.spark.sql.{DataFrame, SQLContext}
+import org.apache.spark.sql.{DataFrame, Dataset, SQLContext}
private[clustering] case class TestRow(features: Vector)
class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
final val k = 5
- @transient var dataset: DataFrame = _
+ @transient var dataset: Dataset[_] = _
override def beforeAll(): Unit = {
super.beforeAll()
@@ -82,7 +82,7 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR
}
}
- test("fit & transform") {
+ test("fit, transform, and summary") {
val predictionColName = "kmeans_prediction"
val kmeans = new KMeans().setK(k).setPredictionCol(predictionColName).setSeed(1)
val model = kmeans.fit(dataset)
@@ -99,6 +99,22 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR
assert(clusters === Set(0, 1, 2, 3, 4))
assert(model.computeCost(dataset) < 0.1)
assert(model.hasParent)
+
+ // Check validity of model summary
+ val numRows = dataset.count()
+ assert(model.hasSummary)
+ val summary: KMeansSummary = model.summary
+ assert(summary.predictionCol === predictionColName)
+ assert(summary.featuresCol === "features")
+ assert(summary.predictions.count() === numRows)
+ for (c <- Array(predictionColName, "features")) {
+ assert(summary.predictions.columns.contains(c))
+ }
+ assert(summary.cluster.columns === Array(predictionColName))
+ val clusterSizes = summary.clusterSizes
+ assert(clusterSizes.length === k)
+ assert(clusterSizes.sum === numRows)
+ assert(clusterSizes.forall(_ >= 0))
}
test("read/write") {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala
index dd3f4c6e53..17d6e9fc2e 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala
@@ -17,12 +17,14 @@
package org.apache.spark.ml.clustering
+import org.apache.hadoop.fs.{FileSystem, Path}
+
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
-import org.apache.spark.sql.{DataFrame, Row, SQLContext}
+import org.apache.spark.sql.{DataFrame, Dataset, Row, SQLContext}
object LDASuite {
@@ -62,7 +64,7 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead
val k: Int = 5
val vocabSize: Int = 30
- @transient var dataset: DataFrame = _
+ @transient var dataset: Dataset[_] = _
override def beforeAll(): Unit = {
super.beforeAll()
@@ -261,4 +263,41 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead
testEstimatorAndModelReadWrite(lda, dataset,
LDASuite.allParamSettings ++ Map("optimizer" -> "em"), checkModelData)
}
+
+ test("EM LDA checkpointing: save last checkpoint") {
+ // Checkpoint dir is set by MLlibTestSparkContext
+ val lda = new LDA().setK(2).setSeed(1).setOptimizer("em").setMaxIter(3).setCheckpointInterval(1)
+ val model_ = lda.fit(dataset)
+ assert(model_.isInstanceOf[DistributedLDAModel])
+ val model = model_.asInstanceOf[DistributedLDAModel]
+
+ // There should be 1 checkpoint remaining.
+ assert(model.getCheckpointFiles.length === 1)
+ val fs = FileSystem.get(sqlContext.sparkContext.hadoopConfiguration)
+ assert(fs.exists(new Path(model.getCheckpointFiles.head)))
+ model.deleteCheckpointFiles()
+ assert(model.getCheckpointFiles.isEmpty)
+ }
+
+ test("EM LDA checkpointing: remove last checkpoint") {
+ // Checkpoint dir is set by MLlibTestSparkContext
+ val lda = new LDA().setK(2).setSeed(1).setOptimizer("em").setMaxIter(3).setCheckpointInterval(1)
+ .setKeepLastCheckpoint(false)
+ val model_ = lda.fit(dataset)
+ assert(model_.isInstanceOf[DistributedLDAModel])
+ val model = model_.asInstanceOf[DistributedLDAModel]
+
+ assert(model.getCheckpointFiles.isEmpty)
+ }
+
+ test("EM LDA disable checkpointing") {
+ // Checkpoint dir is set by MLlibTestSparkContext
+ val lda = new LDA().setK(2).setSeed(1).setOptimizer("em").setMaxIter(3)
+ .setCheckpointInterval(-1)
+ val model_ = lda.fit(dataset)
+ assert(model_.isInstanceOf[DistributedLDAModel])
+ val model = model_.asInstanceOf[DistributedLDAModel]
+
+ assert(model.getCheckpointFiles.isEmpty)
+ }
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala
index 04f165c5f1..7641e3b8cf 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala
@@ -59,14 +59,15 @@ class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext
(0, split("a b c d e"),
Vectors.sparse(5, Seq((0, 1.0), (1, 1.0), (2, 1.0), (3, 1.0), (4, 1.0)))),
(1, split("a a a a a a"), Vectors.sparse(5, Seq((0, 6.0)))),
- (2, split("c"), Vectors.sparse(5, Seq((2, 1.0)))),
- (3, split("b b b b b"), Vectors.sparse(5, Seq((1, 5.0)))))
+ (2, split("c c"), Vectors.sparse(5, Seq((2, 2.0)))),
+ (3, split("d"), Vectors.sparse(5, Seq((3, 1.0)))),
+ (4, split("b b b b b"), Vectors.sparse(5, Seq((1, 5.0)))))
).toDF("id", "words", "expected")
val cv = new CountVectorizer()
.setInputCol("words")
.setOutputCol("features")
.fit(df)
- assert(cv.vocabulary === Array("a", "b", "c", "d", "e"))
+ assert(cv.vocabulary.toSet === Set("a", "b", "c", "d", "e"))
cv.transform(df).select("features", "expected").collect().foreach {
case Row(features: Vector, expected: Vector) =>
@@ -168,21 +169,34 @@ class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext
}
}
- test("CountVectorizerModel with binary") {
+ test("CountVectorizerModel and CountVectorizer with binary") {
val df = sqlContext.createDataFrame(Seq(
- (0, split("a a a b b c"), Vectors.sparse(4, Seq((0, 1.0), (1, 1.0), (2, 1.0)))),
+ (0, split("a a a a b b b b c d"),
+ Vectors.sparse(4, Seq((0, 1.0), (1, 1.0), (2, 1.0), (3, 1.0)))),
(1, split("c c c"), Vectors.sparse(4, Seq((2, 1.0)))),
(2, split("a"), Vectors.sparse(4, Seq((0, 1.0))))
)).toDF("id", "words", "expected")
- val cv = new CountVectorizerModel(Array("a", "b", "c", "d"))
+ // CountVectorizer test
+ val cv = new CountVectorizer()
.setInputCol("words")
.setOutputCol("features")
.setBinary(true)
+ .fit(df)
cv.transform(df).select("features", "expected").collect().foreach {
case Row(features: Vector, expected: Vector) =>
assert(features ~== expected absTol 1e-14)
}
+
+ // CountVectorizerModel test
+ val cv2 = new CountVectorizerModel(cv.vocabulary)
+ .setInputCol("words")
+ .setOutputCol("features")
+ .setBinary(true)
+ cv2.transform(df).select("features", "expected").collect().foreach {
+ case Row(features: Vector, expected: Vector) =>
+ assert(features ~== expected absTol 1e-14)
+ }
}
test("CountVectorizer read/write") {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala
index 0dcd0f4946..addd733c20 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala
@@ -46,12 +46,30 @@ class HashingTFSuite extends SparkFunSuite with MLlibTestSparkContext with Defau
require(attrGroup.numAttributes === Some(n))
val features = output.select("features").first().getAs[Vector](0)
// Assume perfect hash on "a", "b", "c", and "d".
- def idx(any: Any): Int = Utils.nonNegativeMod(any.##, n)
+ def idx: Any => Int = featureIdx(n)
val expected = Vectors.sparse(n,
Seq((idx("a"), 2.0), (idx("b"), 2.0), (idx("c"), 1.0), (idx("d"), 1.0)))
assert(features ~== expected absTol 1e-14)
}
+ test("applying binary term freqs") {
+ val df = sqlContext.createDataFrame(Seq(
+ (0, "a a b c c c".split(" ").toSeq)
+ )).toDF("id", "words")
+ val n = 100
+ val hashingTF = new HashingTF()
+ .setInputCol("words")
+ .setOutputCol("features")
+ .setNumFeatures(n)
+ .setBinary(true)
+ val output = hashingTF.transform(df)
+ val features = output.select("features").first().getAs[Vector](0)
+ def idx: Any => Int = featureIdx(n) // Assume perfect hash on input features
+ val expected = Vectors.sparse(n,
+ Seq((idx("a"), 1.0), (idx("b"), 1.0), (idx("c"), 1.0)))
+ assert(features ~== expected absTol 1e-14)
+ }
+
test("read/write") {
val t = new HashingTF()
.setInputCol("myInputCol")
@@ -59,4 +77,8 @@ class HashingTFSuite extends SparkFunSuite with MLlibTestSparkContext with Defau
.setNumFeatures(10)
testDefaultReadWrite(t)
}
+
+ private def featureIdx(numFeatures: Int)(term: Any): Int = {
+ Utils.nonNegativeMod(term.##, numFeatures)
+ }
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/NGramSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/NGramSuite.scala
index 58fda29aa1..e4e15f4331 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/NGramSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/NGramSuite.scala
@@ -22,7 +22,7 @@ import scala.beans.BeanInfo
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.util.DefaultReadWriteTest
import org.apache.spark.mllib.util.MLlibTestSparkContext
-import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.{DataFrame, Dataset, Row}
@BeanInfo
case class NGramTestData(inputTokens: Array[String], wantedNGrams: Array[String])
@@ -92,7 +92,7 @@ class NGramSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRe
object NGramSuite extends SparkFunSuite {
- def testNGram(t: NGram, dataset: DataFrame): Unit = {
+ def testNGram(t: NGram, dataset: Dataset[_]): Unit = {
t.transform(dataset)
.select("nGrams", "wantedNGrams")
.collect()
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala
index 25fabf64d5..8895d630a0 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala
@@ -17,78 +17,60 @@
package org.apache.spark.ml.feature
-import org.apache.spark.{SparkContext, SparkFunSuite}
-import org.apache.spark.ml.attribute.{Attribute, NominalAttribute}
+import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.util.DefaultReadWriteTest
import org.apache.spark.mllib.util.MLlibTestSparkContext
-import org.apache.spark.sql.{Row, SQLContext}
+import org.apache.spark.sql.SQLContext
+import org.apache.spark.sql.functions.udf
class QuantileDiscretizerSuite
extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
- import org.apache.spark.ml.feature.QuantileDiscretizerSuite._
-
- test("Test quantile discretizer") {
- checkDiscretizedData(sc,
- Array[Double](1, 2, 3, 3, 3, 3, 3, 3, 3),
- 10,
- Array[Double](1, 2, 3, 3, 3, 3, 3, 3, 3),
- Array("-Infinity, 1.0", "1.0, 2.0", "2.0, 3.0", "3.0, Infinity"))
-
- checkDiscretizedData(sc,
- Array[Double](1, 2, 3, 3, 3, 3, 3, 3, 3),
- 4,
- Array[Double](1, 2, 3, 3, 3, 3, 3, 3, 3),
- Array("-Infinity, 1.0", "1.0, 2.0", "2.0, 3.0", "3.0, Infinity"))
-
- checkDiscretizedData(sc,
- Array[Double](1, 2, 3, 3, 3, 3, 3, 3, 3),
- 3,
- Array[Double](0, 1, 2, 2, 2, 2, 2, 2, 2),
- Array("-Infinity, 2.0", "2.0, 3.0", "3.0, Infinity"))
+ test("Test observed number of buckets and their sizes match expected values") {
+ val sqlCtx = SQLContext.getOrCreate(sc)
+ import sqlCtx.implicits._
- checkDiscretizedData(sc,
- Array[Double](1, 2, 3, 3, 3, 3, 3, 3, 3),
- 2,
- Array[Double](0, 1, 1, 1, 1, 1, 1, 1, 1),
- Array("-Infinity, 2.0", "2.0, Infinity"))
+ val datasetSize = 100000
+ val numBuckets = 5
+ val df = sc.parallelize(1.0 to datasetSize by 1.0).map(Tuple1.apply).toDF("input")
+ val discretizer = new QuantileDiscretizer()
+ .setInputCol("input")
+ .setOutputCol("result")
+ .setNumBuckets(numBuckets)
+ val result = discretizer.fit(df).transform(df)
- }
+ val observedNumBuckets = result.select("result").distinct.count
+ assert(observedNumBuckets === numBuckets,
+ "Observed number of buckets does not equal expected number of buckets.")
- test("Test getting splits") {
- val splitTestPoints = Array(
- Array[Double]() -> Array(Double.NegativeInfinity, 0, Double.PositiveInfinity),
- Array(Double.NegativeInfinity) -> Array(Double.NegativeInfinity, 0, Double.PositiveInfinity),
- Array(Double.PositiveInfinity) -> Array(Double.NegativeInfinity, 0, Double.PositiveInfinity),
- Array(Double.NegativeInfinity, Double.PositiveInfinity)
- -> Array(Double.NegativeInfinity, 0, Double.PositiveInfinity),
- Array(0.0) -> Array(Double.NegativeInfinity, 0, Double.PositiveInfinity),
- Array(1.0) -> Array(Double.NegativeInfinity, 1, Double.PositiveInfinity),
- Array(0.0, 1.0) -> Array(Double.NegativeInfinity, 0, 1, Double.PositiveInfinity)
- )
- for ((ori, res) <- splitTestPoints) {
- assert(QuantileDiscretizer.getSplits(ori) === res, "Returned splits are invalid.")
+ val relativeError = discretizer.getRelativeError
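+ // A bucket size counts as "good" when it is within relativeError * datasetSize of the ideal size datasetSize / numBuckets.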
+ val isGoodBucket = udf {
+ (size: Int) => math.abs(size - (datasetSize / numBuckets)) <= (relativeError * datasetSize)
}
+ val numGoodBuckets = result.groupBy("result").count.filter(isGoodBucket($"count")).count
+ assert(numGoodBuckets === numBuckets,
+ "Bucket sizes are not within expected relative error tolerance.")
}
- test("Test splits on dataset larger than minSamplesRequired") {
+ test("Test transform method on unseen data") {
val sqlCtx = SQLContext.getOrCreate(sc)
import sqlCtx.implicits._
- val datasetSize = QuantileDiscretizer.minSamplesRequired + 1
- val numBuckets = 5
- val df = sc.parallelize((1.0 to datasetSize by 1.0).map(Tuple1.apply)).toDF("input")
+ val trainDF = sc.parallelize(1.0 to 100.0 by 1.0).map(Tuple1.apply).toDF("input")
+ val testDF = sc.parallelize(-10.0 to 110.0 by 1.0).map(Tuple1.apply).toDF("input")
val discretizer = new QuantileDiscretizer()
.setInputCol("input")
.setOutputCol("result")
- .setNumBuckets(numBuckets)
- .setSeed(1)
+ .setNumBuckets(5)
- val result = discretizer.fit(df).transform(df)
- val observedNumBuckets = result.select("result").distinct.count
+ val result = discretizer.fit(trainDF).transform(testDF)
+ val firstBucketSize = result.filter(result("result") === 0.0).count
+ val lastBucketSize = result.filter(result("result") === 4.0).count
- assert(observedNumBuckets === numBuckets,
- "Observed number of buckets does not equal expected number of buckets.")
+ assert(firstBucketSize === 30L,
+ s"Size of first bucket ${firstBucketSize} did not equal expected value of 30.")
+ assert(lastBucketSize === 31L,
+ s"Size of last bucket ${lastBucketSize} did not equal expected value of 31.")
}
test("read/write") {
@@ -98,34 +80,17 @@ class QuantileDiscretizerSuite
.setNumBuckets(6)
testDefaultReadWrite(t)
}
-}
-
-private object QuantileDiscretizerSuite extends SparkFunSuite {
- def checkDiscretizedData(
- sc: SparkContext,
- data: Array[Double],
- numBucket: Int,
- expectedResult: Array[Double],
- expectedAttrs: Array[String]): Unit = {
+ test("Verify resulting model has parent") {
val sqlCtx = SQLContext.getOrCreate(sc)
import sqlCtx.implicits._
- val df = sc.parallelize(data.map(Tuple1.apply)).toDF("input")
- val discretizer = new QuantileDiscretizer().setInputCol("input").setOutputCol("result")
- .setNumBuckets(numBucket).setSeed(1)
+ val df = sc.parallelize(1 to 100).map(Tuple1.apply).toDF("input")
+ val discretizer = new QuantileDiscretizer()
+ .setInputCol("input")
+ .setOutputCol("result")
+ .setNumBuckets(5)
val model = discretizer.fit(df)
assert(model.hasParent)
- val result = model.transform(df)
-
- val transformedFeatures = result.select("result").collect()
- .map { case Row(transformedFeature: Double) => transformedFeature }
- val transformedAttrs = Attribute.fromStructField(result.schema("result"))
- .asInstanceOf[NominalAttribute].values.get
-
- assert(transformedFeatures === expectedResult,
- "Transformed features do not equal expected features.")
- assert(transformedAttrs === expectedAttrs,
- "Transformed attributes do not equal expected attributes.")
}
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/SQLTransformerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/SQLTransformerSuite.scala
index 553e0b8702..e213e17d0d 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/SQLTransformerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/SQLTransformerSuite.scala
@@ -21,6 +21,7 @@ import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.ml.util.DefaultReadWriteTest
import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.sql.types.{LongType, StructField, StructType}
class SQLTransformerSuite
extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
@@ -49,4 +50,13 @@ class SQLTransformerSuite
.setStatement("select * from __THIS__")
testDefaultReadWrite(t)
}
+
+ test("transformSchema") {
+ val df = sqlContext.range(10)
+ val outputSchema = new SQLTransformer()
+ .setStatement("SELECT id + 1 AS id1 FROM __THIS__")
+ .transformSchema(df.schema)
+ val expected = StructType(Seq(StructField("id1", LongType, nullable = false)))
+ assert(outputSchema === expected)
+ }
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala
index a5b24c1856..3505befdf8 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala
@@ -20,10 +20,10 @@ package org.apache.spark.ml.feature
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.util.DefaultReadWriteTest
import org.apache.spark.mllib.util.MLlibTestSparkContext
-import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.{DataFrame, Dataset, Row}
object StopWordsRemoverSuite extends SparkFunSuite {
- def testStopWordsRemover(t: StopWordsRemover, dataset: DataFrame): Unit = {
+ def testStopWordsRemover(t: StopWordsRemover, dataset: Dataset[_]): Unit = {
t.transform(dataset)
.select("filtered", "expected")
.collect()
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
index 2c3255ef33..d0f3cdc841 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
@@ -115,7 +115,7 @@ class StringIndexerSuite
.setInputCol("label")
.setOutputCol("labelIndex")
val df = sqlContext.range(0L, 10L).toDF()
- assert(indexerModel.transform(df).eq(df))
+ assert(indexerModel.transform(df).collect().toSet === df.collect().toSet)
}
test("StringIndexerModel can't overwrite output column") {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala
index 36e8e5d868..299f6223b2 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala
@@ -23,7 +23,7 @@ import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.ml.util.DefaultReadWriteTest
import org.apache.spark.mllib.util.MLlibTestSparkContext
-import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.{DataFrame, Dataset, Row}
@BeanInfo
case class TokenizerTestData(rawText: String, wantedTokens: Array[String])
@@ -106,7 +106,7 @@ class RegexTokenizerSuite
object RegexTokenizerSuite extends SparkFunSuite {
- def testRegexTokenizer(t: RegexTokenizer, dataset: DataFrame): Unit = {
+ def testRegexTokenizer(t: RegexTokenizer, dataset: Dataset[_]): Unit = {
t.transform(dataset)
.select("tokens", "wantedTokens")
.collect()
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala
index dbd752d2aa..76891ad562 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala
@@ -33,6 +33,7 @@ class AFTSurvivalRegressionSuite
@transient var datasetUnivariate: DataFrame = _
@transient var datasetMultivariate: DataFrame = _
+ @transient var datasetUnivariateScaled: DataFrame = _
override def beforeAll(): Unit = {
super.beforeAll()
@@ -42,6 +43,11 @@ class AFTSurvivalRegressionSuite
datasetMultivariate = sqlContext.createDataFrame(
sc.parallelize(generateAFTInput(
2, Array(0.9, -1.3), Array(0.7, 1.2), 1000, 42, 1.5, 2.5, 2.0)))
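+ // Univariate AFT data whose features are multiplied by 1.0E3, used by the "numerical stability of standardization" test below.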
+ datasetUnivariateScaled = sqlContext.createDataFrame(
+ sc.parallelize(generateAFTInput(
+ 1, Array(5.5), Array(0.8), 1000, 42, 1.0, 2.0, 2.0)).map { x =>
+ AFTPoint(Vectors.dense(x.features(0) * 1.0E3), x.label, x.censor)
+ })
}
/**
@@ -347,6 +353,31 @@ class AFTSurvivalRegressionSuite
}
}
+ test("should support all NumericType labels") {
+ val aft = new AFTSurvivalRegression().setMaxIter(1)
+ MLTestingUtils.checkNumericTypes[AFTSurvivalRegressionModel, AFTSurvivalRegression](
+ aft, isClassification = false, sqlContext) { (expected, actual) =>
+ assert(expected.intercept === actual.intercept)
+ assert(expected.coefficients === actual.coefficients)
+ }
+ }
+
+ test("numerical stability of standardization") {
+ val trainer = new AFTSurvivalRegression()
+ val model1 = trainer.fit(datasetUnivariate)
+ val model2 = trainer.fit(datasetUnivariateScaled)
+
+ /**
+ * During training we standardize the dataset first, so no matter what scaling
+ * factor we multiply into the dataset, the convergence rate should be the same,
+ * and the coefficients should equal the original coefficients multiplied by
+ * the scaling factor. The scaling has no effect on the intercept and scale.
+ */
+ assert(model1.coefficients(0) ~== model2.coefficients(0) * 1.0E3 absTol 0.01)
+ assert(model1.intercept ~== model2.intercept absTol 0.01)
+ assert(model1.scale ~== model2.scale absTol 0.01)
+ }
+
test("read/write") {
def checkModelData(
model: AFTSurvivalRegressionModel,
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
index 662e3fc679..e9fb2677b2 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
@@ -117,6 +117,14 @@ class DecisionTreeRegressorSuite
assert(importances.toArray.forall(_ >= 0.0))
}
+ test("should support all NumericType labels and not support other types") {
+ val dt = new DecisionTreeRegressor().setMaxDepth(1)
+ MLTestingUtils.checkNumericTypes[DecisionTreeRegressionModel, DecisionTreeRegressor](
+ dt, isClassification = false, sqlContext) { (expected, actual) =>
+ TreeTests.checkEqual(expected, actual)
+ }
+ }
+
/////////////////////////////////////////////////////////////////////////////
// Tests of model save/load
/////////////////////////////////////////////////////////////////////////////
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
index 84148a8a4a..216377959e 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
@@ -19,7 +19,7 @@ package org.apache.spark.ml.regression
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.tree.impl.TreeTests
-import org.apache.spark.ml.util.MLTestingUtils
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.{EnsembleTestHelper, GradientBoostedTrees => OldGBT}
@@ -29,11 +29,11 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.DataFrame
import org.apache.spark.util.Utils
-
/**
* Test suite for [[GBTRegressor]].
*/
-class GBTRegressorSuite extends SparkFunSuite with MLlibTestSparkContext {
+class GBTRegressorSuite extends SparkFunSuite with MLlibTestSparkContext
+ with DefaultReadWriteTest {
import GBTRegressorSuite.compareAPIs
@@ -54,7 +54,7 @@ class GBTRegressorSuite extends SparkFunSuite with MLlibTestSparkContext {
sc.parallelize(EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 20, 80), 2)
}
- test("Regression with continuous features: SquaredError") {
+ test("Regression with continuous features") {
val categoricalFeatures = Map.empty[Int, Int]
GBTRegressor.supportedLossTypes.foreach { loss =>
testCombinations.foreach {
@@ -110,7 +110,14 @@ class GBTRegressorSuite extends SparkFunSuite with MLlibTestSparkContext {
sc.checkpointDir = None
Utils.deleteRecursively(tempDir)
+ }
+ test("should support all NumericType labels and not support other types") {
+ val gbt = new GBTRegressor().setMaxDepth(1)
+ MLTestingUtils.checkNumericTypes[GBTRegressionModel, GBTRegressor](
+ gbt, isClassification = false, sqlContext) { (expected, actual) =>
+ TreeTests.checkEqual(expected, actual)
+ }
}
// TODO: Reinstate test once runWithValidation is implemented SPARK-7132
@@ -132,30 +139,48 @@ class GBTRegressorSuite extends SparkFunSuite with MLlibTestSparkContext {
*/
/////////////////////////////////////////////////////////////////////////////
+ // Tests of feature importance
+ /////////////////////////////////////////////////////////////////////////////
+ test("Feature importance with toy data") {
+ val gbt = new GBTRegressor()
+ .setMaxDepth(3)
+ .setMaxIter(5)
+ .setSubsamplingRate(1.0)
+ .setStepSize(0.5)
+ .setSeed(123)
+
+ // In this data, feature 1 is very important.
+ val data: RDD[LabeledPoint] = TreeTests.featureImportanceData(sc)
+ val categoricalFeatures = Map.empty[Int, Int]
+ val df: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, 0)
+
+ val importances = gbt.fit(df).featureImportances
+ val mostImportantFeature = importances.argmax
+ assert(mostImportantFeature === 1)
+ assert(importances.toArray.sum === 1.0)
+ assert(importances.toArray.forall(_ >= 0.0))
+ }
+
+ /////////////////////////////////////////////////////////////////////////////
// Tests of model save/load
/////////////////////////////////////////////////////////////////////////////
- // TODO: Reinstate test once save/load are implemented SPARK-6725
- /*
test("model save/load") {
- val tempDir = Utils.createTempDir()
- val path = tempDir.toURI.toString
-
- val trees = Range(0, 3).map(_ => OldDecisionTreeSuite.createModel(OldAlgo.Regression)).toArray
- val treeWeights = Array(0.1, 0.3, 1.1)
- val oldModel = new OldGBTModel(OldAlgo.Regression, trees, treeWeights)
- val newModel = GBTRegressionModel.fromOld(oldModel)
-
- // Save model, load it back, and compare.
- try {
- newModel.save(sc, path)
- val sameNewModel = GBTRegressionModel.load(sc, path)
- TreeTests.checkEqual(newModel, sameNewModel)
- } finally {
- Utils.deleteRecursively(tempDir)
+ def checkModelData(
+ model: GBTRegressionModel,
+ model2: GBTRegressionModel): Unit = {
+ TreeTests.checkEqual(model, model2)
+ assert(model.numFeatures === model2.numFeatures)
}
+
+ val gbt = new GBTRegressor()
+ val rdd = TreeTests.getTreeReadWriteData(sc)
+
+ val allParamSettings = TreeTests.allParamSettings ++ Map("lossType" -> "squared")
+ val continuousData: DataFrame =
+ TreeTests.setMetadata(rdd, Map.empty[Int, Int], numClasses = 0)
+ testEstimatorAndModelReadWrite(gbt, continuousData, allParamSettings, checkModelData)
}
- */
}
private object GBTRegressorSuite extends SparkFunSuite {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala
index 4ebdbf2213..3ecc210abd 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala
@@ -626,6 +626,7 @@ class GeneralizedLinearRegressionSuite
assert(summary.residualDegreeOfFreedom === residualDegreeOfFreedomR)
assert(summary.residualDegreeOfFreedomNull === residualDegreeOfFreedomNullR)
assert(summary.aic ~== aicR absTol 1E-3)
+ assert(summary.solver === "irls")
}
test("glm summary: binomial family with weight") {
@@ -739,6 +740,7 @@ class GeneralizedLinearRegressionSuite
assert(summary.residualDegreeOfFreedom === residualDegreeOfFreedomR)
assert(summary.residualDegreeOfFreedomNull === residualDegreeOfFreedomNullR)
assert(summary.aic ~== aicR absTol 1E-3)
+ assert(summary.solver === "irls")
}
test("glm summary: poisson family with weight") {
@@ -855,6 +857,7 @@ class GeneralizedLinearRegressionSuite
assert(summary.residualDegreeOfFreedom === residualDegreeOfFreedomR)
assert(summary.residualDegreeOfFreedomNull === residualDegreeOfFreedomNullR)
assert(summary.aic ~== aicR absTol 1E-3)
+ assert(summary.solver === "irls")
}
test("glm summary: gamma family with weight") {
@@ -968,6 +971,7 @@ class GeneralizedLinearRegressionSuite
assert(summary.residualDegreeOfFreedom === residualDegreeOfFreedomR)
assert(summary.residualDegreeOfFreedomNull === residualDegreeOfFreedomNullR)
assert(summary.aic ~== aicR absTol 1E-3)
+ assert(summary.solver === "irls")
}
test("read/write") {
@@ -982,6 +986,24 @@ class GeneralizedLinearRegressionSuite
testEstimatorAndModelReadWrite(glr, datasetPoissonLog,
GeneralizedLinearRegressionSuite.allParamSettings, checkModelData)
}
+
+ test("should support all NumericType labels and not support other types") {
+ val glr = new GeneralizedLinearRegression().setMaxIter(1)
+ MLTestingUtils.checkNumericTypes[
+ GeneralizedLinearRegressionModel, GeneralizedLinearRegression](
+ glr, isClassification = false, sqlContext) { (expected, actual) =>
+ assert(expected.intercept === actual.intercept)
+ assert(expected.coefficients === actual.coefficients)
+ }
+ }
+
+ test("glm accepts Dataset[LabeledPoint]") {
+ val context = sqlContext
+ import context.implicits._
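+ // Smoke test: passes as long as fit() on a typed Dataset[LabeledPoint] completes without throwing.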
+ new GeneralizedLinearRegression()
+ .setFamily("gaussian")
+ .fit(datasetGaussianIdentity.as[LabeledPoint])
+ }
}
object GeneralizedLinearRegressionSuite {
@@ -1023,7 +1045,7 @@ object GeneralizedLinearRegressionSuite {
generator.setSeed(seed)
(0 until nPoints).map { _ =>
- val features = Vectors.dense(coefficients.indices.map { rndElement(_) }.toArray)
+ val features = Vectors.dense(coefficients.indices.map(rndElement).toArray)
val eta = BLAS.dot(Vectors.dense(coefficients), features) + intercept
val mu = link match {
case "identity" => eta
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala
index b8874b4cd3..3a10ad7ed0 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala
@@ -180,6 +180,15 @@ class IsotonicRegressionSuite
testEstimatorAndModelReadWrite(ir, dataset, IsotonicRegressionSuite.allParamSettings,
checkModelData)
}
+
+ test("should support all NumericType labels and not support other types") {
+ val ir = new IsotonicRegression()
+ MLTestingUtils.checkNumericTypes[IsotonicRegressionModel, IsotonicRegression](
+ ir, isClassification = false, sqlContext) { (expected, actual) =>
+ assert(expected.boundaries === actual.boundaries)
+ assert(expected.predictions === actual.predictions)
+ }
+ }
}
object IsotonicRegressionSuite {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
index bd45d21e8d..eb19d13093 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
@@ -61,9 +61,9 @@ class LinearRegressionSuite
val featureSize = 4100
datasetWithSparseFeature = sqlContext.createDataFrame(
sc.parallelize(LinearDataGenerator.generateLinearInput(
- intercept = 0.0, weights = Seq.fill(featureSize)(r.nextDouble).toArray,
- xMean = Seq.fill(featureSize)(r.nextDouble).toArray,
- xVariance = Seq.fill(featureSize)(r.nextDouble).toArray, nPoints = 200,
+ intercept = 0.0, weights = Seq.fill(featureSize)(r.nextDouble()).toArray,
+ xMean = Seq.fill(featureSize)(r.nextDouble()).toArray,
+ xVariance = Seq.fill(featureSize)(r.nextDouble()).toArray, nPoints = 200,
seed, eps = 0.1, sparsity = 0.7), 2))
/*
@@ -687,7 +687,7 @@ class LinearRegressionSuite
// Validate that we re-insert a prediction column for evaluation
val modelNoPredictionColFieldNames
= modelNoPredictionCol.summary.predictions.schema.fieldNames
- assert((datasetWithDenseFeature.schema.fieldNames.toSet).subsetOf(
+ assert(datasetWithDenseFeature.schema.fieldNames.toSet.subsetOf(
modelNoPredictionColFieldNames.toSet))
assert(modelNoPredictionColFieldNames.exists(s => s.startsWith("prediction_")))
@@ -759,7 +759,7 @@ class LinearRegressionSuite
.sliding(2)
.forall(x => x(0) >= x(1)))
} else {
- // To clalify that the normal solver is used here.
+ // To clarify that the normal solver is used here.
assert(model.summary.objectiveHistory.length == 1)
assert(model.summary.objectiveHistory(0) == 0.0)
val devianceResidualsR = Array(-0.47082, 0.34635)
@@ -1006,6 +1006,15 @@ class LinearRegressionSuite
testEstimatorAndModelReadWrite(lr, datasetWithWeight, LinearRegressionSuite.allParamSettings,
checkModelData)
}
+
+ test("should support all NumericType labels and not support other types") {
+ val lr = new LinearRegression().setMaxIter(1)
+ MLTestingUtils.checkNumericTypes[LinearRegressionModel, LinearRegression](
+ lr, isClassification = false, sqlContext) { (expected, actual) =>
+ assert(expected.intercept === actual.intercept)
+ assert(expected.coefficients === actual.coefficients)
+ }
+ }
}
object LinearRegressionSuite {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala
index 6be0c8bca0..ca400e1914 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala
@@ -19,7 +19,7 @@ package org.apache.spark.ml.regression
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.tree.impl.TreeTests
-import org.apache.spark.ml.util.MLTestingUtils
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.{EnsembleTestHelper, RandomForest => OldRandomForest}
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
@@ -30,7 +30,8 @@ import org.apache.spark.sql.DataFrame
/**
* Test suite for [[RandomForestRegressor]].
*/
-class RandomForestRegressorSuite extends SparkFunSuite with MLlibTestSparkContext {
+class RandomForestRegressorSuite extends SparkFunSuite with MLlibTestSparkContext
+ with DefaultReadWriteTest {
import RandomForestRegressorSuite.compareAPIs
@@ -94,30 +95,35 @@ class RandomForestRegressorSuite extends SparkFunSuite with MLlibTestSparkContex
assert(importances.toArray.forall(_ >= 0.0))
}
+ test("should support all NumericType labels and not support other types") {
+ val rf = new RandomForestRegressor().setMaxDepth(1)
+ MLTestingUtils.checkNumericTypes[RandomForestRegressionModel, RandomForestRegressor](
+ rf, isClassification = false, sqlContext) { (expected, actual) =>
+ TreeTests.checkEqual(expected, actual)
+ }
+ }
+
/////////////////////////////////////////////////////////////////////////////
// Tests of model save/load
/////////////////////////////////////////////////////////////////////////////
- // TODO: Reinstate test once save/load are implemented SPARK-6725
- /*
- test("model save/load") {
- val tempDir = Utils.createTempDir()
- val path = tempDir.toURI.toString
-
- val trees = Range(0, 3).map(_ => OldDecisionTreeSuite.createModel(OldAlgo.Regression)).toArray
- val oldModel = new OldRandomForestModel(OldAlgo.Regression, trees)
- val newModel = RandomForestRegressionModel.fromOld(oldModel)
-
- // Save model, load it back, and compare.
- try {
- newModel.save(sc, path)
- val sameNewModel = RandomForestRegressionModel.load(sc, path)
- TreeTests.checkEqual(newModel, sameNewModel)
- } finally {
- Utils.deleteRecursively(tempDir)
+ test("read/write") {
+ def checkModelData(
+ model: RandomForestRegressionModel,
+ model2: RandomForestRegressionModel): Unit = {
+ TreeTests.checkEqual(model, model2)
+ assert(model.numFeatures === model2.numFeatures)
}
+
+ val rf = new RandomForestRegressor().setNumTrees(2)
+ val rdd = TreeTests.getTreeReadWriteData(sc)
+
+ val allParamSettings = TreeTests.allParamSettings ++ Map("impurity" -> "variance")
+
+ val continuousData: DataFrame =
+ TreeTests.setMetadata(rdd, Map.empty[Int, Int], numClasses = 0)
+ testEstimatorAndModelReadWrite(rf, continuousData, allParamSettings, checkModelData)
}
- */
}
private object RandomForestRegressorSuite extends SparkFunSuite {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala
index 114a238462..0bd14978b2 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala
@@ -28,8 +28,9 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.SaveMode
import org.apache.spark.util.Utils
+
class LibSVMRelationSuite extends SparkFunSuite with MLlibTestSparkContext {
- var tempDir: File = _
+ // Path for dataset
var path: String = _
override def beforeAll(): Unit = {
@@ -40,15 +41,15 @@ class LibSVMRelationSuite extends SparkFunSuite with MLlibTestSparkContext {
|0
|0 2:4.0 4:5.0 6:6.0
""".stripMargin
- tempDir = Utils.createTempDir()
- val file = new File(tempDir, "part-00000")
+ val dir = Utils.createDirectory(tempDir.getCanonicalPath, "data")
+ val file = new File(dir, "part-00000")
Files.write(lines, file, StandardCharsets.UTF_8)
- path = tempDir.toURI.toString
+ path = dir.toURI.toString
}
override def afterAll(): Unit = {
try {
- Utils.deleteRecursively(tempDir)
+ Utils.deleteRecursively(new File(path))
} finally {
super.afterAll()
}
@@ -86,7 +87,7 @@ class LibSVMRelationSuite extends SparkFunSuite with MLlibTestSparkContext {
test("write libsvm data and read it again") {
val df = sqlContext.read.format("libsvm").load(path)
- val tempDir2 = Utils.createTempDir()
+ val tempDir2 = new File(tempDir, "read_write_test")
val writepath = tempDir2.toURI.toString
// TODO: Remove requirement to coalesce by supporting multiple reads.
df.coalesce(1).write.format("libsvm").mode(SaveMode.Overwrite).save(writepath)
@@ -99,7 +100,7 @@ class LibSVMRelationSuite extends SparkFunSuite with MLlibTestSparkContext {
test("write libsvm data failed due to invalid schema") {
val df = sqlContext.read.format("text").load(path)
- val e = intercept[SparkException] {
+ intercept[SparkException] {
df.write.format("libsvm").save(path + "_2")
}
}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/impl/BaggedPointSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/BaggedPointSuite.scala
index 9d756da410..77ab3d8bb7 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/impl/BaggedPointSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/BaggedPointSuite.scala
@@ -15,7 +15,7 @@
* limitations under the License.
*/
-package org.apache.spark.mllib.tree.impl
+package org.apache.spark.ml.tree.impl
import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.tree.EnsembleTestHelper
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/GradientBoostedTreesSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/GradientBoostedTreesSuite.scala
new file mode 100644
index 0000000000..fecf372c3d
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/GradientBoostedTreesSuite.scala
@@ -0,0 +1,85 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.tree.impl
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.internal.Logging
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.tree.{GradientBoostedTreesSuite => OldGBTSuite}
+import org.apache.spark.mllib.tree.configuration.{BoostingStrategy, Strategy}
+import org.apache.spark.mllib.tree.configuration.Algo._
+import org.apache.spark.mllib.tree.impurity.Variance
+import org.apache.spark.mllib.tree.loss.{AbsoluteError, LogLoss, SquaredError}
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+
+/**
+ * Test suite for [[GradientBoostedTrees]].
+ */
+class GradientBoostedTreesSuite extends SparkFunSuite with MLlibTestSparkContext with Logging {
+
+ test("runWithValidation stops early and performs better on a validation dataset") {
+ // Set numIterations large enough so that it stops early.
+ val numIterations = 20
+ val trainRdd = sc.parallelize(OldGBTSuite.trainData, 2)
+ val validateRdd = sc.parallelize(OldGBTSuite.validateData, 2)
+ val trainDF = sqlContext.createDataFrame(trainRdd)
+ val validateDF = sqlContext.createDataFrame(validateRdd)
+
+ val algos = Array(Regression, Regression, Classification)
+ val losses = Array(SquaredError, AbsoluteError, LogLoss)
+ algos.zip(losses).foreach { case (algo, loss) =>
+ val treeStrategy = new Strategy(algo = algo, impurity = Variance, maxDepth = 2,
+ categoricalFeaturesInfo = Map.empty)
+ val boostingStrategy =
+ new BoostingStrategy(treeStrategy, loss, numIterations, validationTol = 0.0)
+ val (validateTrees, validateTreeWeights) = GradientBoostedTrees
+ .runWithValidation(trainRdd, validateRdd, boostingStrategy, 42L)
+ val numTrees = validateTrees.length
+ assert(numTrees !== numIterations)
+
+ // Test that it performs better on the validation dataset.
+ val (trees, treeWeights) = GradientBoostedTrees.run(trainRdd, boostingStrategy, 42L)
+ val (errorWithoutValidation, errorWithValidation) = {
+ if (algo == Classification) {
+ val remappedRdd = validateRdd.map(x => new LabeledPoint(2 * x.label - 1, x.features))
+ (GradientBoostedTrees.computeError(remappedRdd, trees, treeWeights, loss),
+ GradientBoostedTrees.computeError(remappedRdd, validateTrees,
+ validateTreeWeights, loss))
+ } else {
+ (GradientBoostedTrees.computeError(validateRdd, trees, treeWeights, loss),
+ GradientBoostedTrees.computeError(validateRdd, validateTrees,
+ validateTreeWeights, loss))
+ }
+ }
+ assert(errorWithValidation <= errorWithoutValidation)
+
+ // Test that results from evaluateEachIteration comply with runWithValidation.
+ // Note that convergenceTol is set to 0.0
+ val evaluationArray = GradientBoostedTrees
+ .evaluateEachIteration(validateRdd, trees, treeWeights, loss, algo)
+ assert(evaluationArray.length === numIterations)
+ assert(evaluationArray(numTrees) > evaluationArray(numTrees - 1))
+ var i = 1
+ while (i < numTrees) {
+ assert(evaluationArray(i) <= evaluationArray(i - 1))
+ i += 1
+ }
+ }
+ }
+
+}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala
index 361366fde7..6db9ce150d 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala
@@ -26,7 +26,6 @@ import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.{DecisionTreeSuite => OldDTSuite, EnsembleTestHelper}
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, QuantileStrategy, Strategy => OldStrategy}
-import org.apache.spark.mllib.tree.impl.{BaggedPoint, DecisionTreeMetadata}
import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, GiniCalculator}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
@@ -328,7 +327,9 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
case n: InternalNode => n.split match {
case s: CategoricalSplit =>
assert(s.leftCategories === Array(1.0))
+ case _ => throw new AssertionError("model.rootNode.split was not a CategoricalSplit")
}
+ case _ => throw new AssertionError("model.rootNode was not an InternalNode")
}
}
@@ -353,6 +354,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
assert(n.leftChild.isInstanceOf[InternalNode])
assert(n.rightChild.isInstanceOf[InternalNode])
Array(n.leftChild.asInstanceOf[InternalNode], n.rightChild.asInstanceOf[InternalNode])
+ case _ => throw new AssertionError("rootNode was not an InternalNode")
}
// Single group second level tree construction.
@@ -424,12 +426,48 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
(math.log(numFeatures) / math.log(2)).ceil.toInt)
checkFeatureSubsetStrategy(numTrees = 1, "onethird", (numFeatures / 3.0).ceil.toInt)
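+ // Fractional strategies select ceil(fraction * numFeatures) features per split; integer strategies are capped at numFeatures.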
+ val realStrategies = Array(".1", ".10", "0.10", "0.1", "0.9", "1.0")
+ for (strategy <- realStrategies) {
+ val expected = (strategy.toDouble * numFeatures).ceil.toInt
+ checkFeatureSubsetStrategy(numTrees = 1, strategy, expected)
+ }
+
+ val integerStrategies = Array("1", "10", "100", "1000", "10000")
+ for (strategy <- integerStrategies) {
+ val expected = if (strategy.toInt < numFeatures) strategy.toInt else numFeatures
+ checkFeatureSubsetStrategy(numTrees = 1, strategy, expected)
+ }
+
+ val invalidStrategies = Array("-.1", "-.10", "-0.10", ".0", "0.0", "1.1", "0")
+ for (invalidStrategy <- invalidStrategies) {
+ intercept[MatchError] {
+ val metadata =
+ DecisionTreeMetadata.buildMetadata(rdd, strategy, numTrees = 1, invalidStrategy)
+ }
+ }
+
checkFeatureSubsetStrategy(numTrees = 2, "all", numFeatures)
checkFeatureSubsetStrategy(numTrees = 2, "auto", math.sqrt(numFeatures).ceil.toInt)
checkFeatureSubsetStrategy(numTrees = 2, "sqrt", math.sqrt(numFeatures).ceil.toInt)
checkFeatureSubsetStrategy(numTrees = 2, "log2",
(math.log(numFeatures) / math.log(2)).ceil.toInt)
checkFeatureSubsetStrategy(numTrees = 2, "onethird", (numFeatures / 3.0).ceil.toInt)
+
+ for (strategy <- realStrategies) {
+ val expected = (strategy.toDouble * numFeatures).ceil.toInt
+ checkFeatureSubsetStrategy(numTrees = 2, strategy, expected)
+ }
+
+ for (strategy <- integerStrategies) {
+ val expected = if (strategy.toInt < numFeatures) strategy.toInt else numFeatures
+ checkFeatureSubsetStrategy(numTrees = 2, strategy, expected)
+ }
+ for (invalidStrategy <- invalidStrategies) {
+ intercept[MatchError] {
+ val metadata =
+ DecisionTreeMetadata.buildMetadata(rdd, strategy, numTrees = 2, invalidStrategy)
+ }
+ }
}
test("Binary classification with continuous features: subsampling features") {
@@ -471,7 +509,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
// Test feature importance computed at different subtrees.
def testNode(node: Node, expected: Map[Int, Double]): Unit = {
val map = new OpenHashMap[Int, Double]()
- RandomForest.computeFeatureImportance(node, map)
+ TreeEnsembleModel.computeFeatureImportance(node, map)
assert(mapToVec(map.toMap) ~== mapToVec(expected) relTol 0.01)
}
@@ -493,7 +531,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
new DecisionTreeClassificationModel(root, numFeatures = 2, numClasses = 3)
.asInstanceOf[DecisionTreeModel]
}
- val importances: Vector = RandomForest.featureImportances(trees, 2)
+ val importances: Vector = TreeEnsembleModel.featureImportances(trees, 2)
val tree2norm = feature0importance + feature1importance
val expected = Vectors.dense((1.0 + feature0importance / tree2norm) / 2.0,
(feature1importance / tree2norm) / 2.0)
@@ -504,7 +542,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
val map = new OpenHashMap[Int, Double]()
map(0) = 1.0
map(2) = 2.0
- RandomForest.normalizeMapValues(map)
+ TreeEnsembleModel.normalizeMapValues(map)
val expected = Map(0 -> 1.0 / 3.0, 2 -> 2.0 / 3.0)
assert(mapToVec(map.toMap) ~== mapToVec(expected) relTol 0.01)
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala
index 12808b0305..b650a9f092 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala
@@ -74,6 +74,24 @@ private[ml] object TreeTests extends SparkFunSuite {
}
/**
+ * Set label metadata (particularly the number of classes) on a DataFrame.
+ * @param data Dataset. Categorical features and labels must already have 0-based indices.
+ * This must be non-empty.
+ * @param numClasses Number of classes label can take. If 0, mark as continuous.
+ * @param labelColName Name of the label column on which to set the metadata.
+ * @return DataFrame with metadata
+ */
+ def setMetadata(data: DataFrame, numClasses: Int, labelColName: String): DataFrame = {
+ val labelAttribute = if (numClasses == 0) {
+ NumericAttribute.defaultAttr.withName(labelColName)
+ } else {
+ NominalAttribute.defaultAttr.withName(labelColName).withNumValues(numClasses)
+ }
+ val labelMetadata = labelAttribute.toMetadata()
+ data.select(data("features"), data(labelColName).as(labelColName, labelMetadata))
+ }
+
+ /**
* Check if the two trees are exactly the same.
* Note: I hesitate to override Node.equals since it could cause problems if users
* make mistakes such as creating loops of Nodes.
@@ -113,7 +131,7 @@ private[ml] object TreeTests extends SparkFunSuite {
* Check if the two models are exactly the same.
* If the models are not equal, this throws an exception.
*/
- def checkEqual(a: TreeEnsembleModel, b: TreeEnsembleModel): Unit = {
+ def checkEqual[M <: DecisionTreeModel](a: TreeEnsembleModel[M], b: TreeEnsembleModel[M]): Unit = {
try {
a.trees.zip(b.trees).foreach { case (treeA, treeB) =>
TreeTests.checkEqual(treeA, treeB)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
index 7af3c6d6ed..3e734aabc5 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
@@ -29,13 +29,13 @@ import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
import org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInput
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext}
-import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.types.{StructField, StructType}
class CrossValidatorSuite
extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
- @transient var dataset: DataFrame = _
+ @transient var dataset: Dataset[_] = _
override def beforeAll(): Unit = {
super.beforeAll()
@@ -311,7 +311,7 @@ object CrossValidatorSuite extends SparkFunSuite {
class MyEstimator(override val uid: String) extends Estimator[MyModel] with HasInputCol {
- override def fit(dataset: DataFrame): MyModel = {
+ override def fit(dataset: Dataset[_]): MyModel = {
throw new UnsupportedOperationException
}
@@ -325,7 +325,7 @@ object CrossValidatorSuite extends SparkFunSuite {
class MyEvaluator extends Evaluator {
- override def evaluate(dataset: DataFrame): Double = {
+ override def evaluate(dataset: Dataset[_]): Double = {
throw new UnsupportedOperationException
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala
index cf8dcefebc..dbee47c847 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala
@@ -19,17 +19,20 @@ package org.apache.spark.ml.tuning
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.{Estimator, Model}
-import org.apache.spark.ml.classification.LogisticRegression
+import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel}
import org.apache.spark.ml.evaluation.{BinaryClassificationEvaluator, Evaluator, RegressionEvaluator}
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.param.shared.HasInputCol
import org.apache.spark.ml.regression.LinearRegression
+import org.apache.spark.ml.util.DefaultReadWriteTest
import org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInput
+import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext}
-import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.types.StructType
-class TrainValidationSplitSuite extends SparkFunSuite with MLlibTestSparkContext {
+class TrainValidationSplitSuite
+ extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
test("train validation with logistic regression") {
val dataset = sqlContext.createDataFrame(
sc.parallelize(generateLogisticInput(1.0, 1.0, 100, 42), 2))
@@ -45,6 +48,7 @@ class TrainValidationSplitSuite extends SparkFunSuite with MLlibTestSparkContext
.setEstimatorParamMaps(lrParamMaps)
.setEvaluator(eval)
.setTrainRatio(0.5)
+ .setSeed(42L)
val cvModel = cv.fit(dataset)
val parent = cvModel.bestModel.parent.asInstanceOf[LogisticRegression]
assert(cv.getTrainRatio === 0.5)
@@ -69,6 +73,7 @@ class TrainValidationSplitSuite extends SparkFunSuite with MLlibTestSparkContext
.setEstimatorParamMaps(lrParamMaps)
.setEvaluator(eval)
.setTrainRatio(0.5)
+ .setSeed(42L)
val cvModel = cv.fit(dataset)
val parent = cvModel.bestModel.parent.asInstanceOf[LinearRegression]
assert(parent.getRegParam === 0.001)
@@ -105,6 +110,46 @@ class TrainValidationSplitSuite extends SparkFunSuite with MLlibTestSparkContext
cv.transformSchema(new StructType())
}
}
+
+ test("read/write: TrainValidationSplit") {
+ val lr = new LogisticRegression().setMaxIter(3)
+ val evaluator = new BinaryClassificationEvaluator()
+ val paramMaps = new ParamGridBuilder()
+ .addGrid(lr.regParam, Array(0.1, 0.2))
+ .build()
+ val tvs = new TrainValidationSplit()
+ .setEstimator(lr)
+ .setEvaluator(evaluator)
+ .setTrainRatio(0.5)
+ .setEstimatorParamMaps(paramMaps)
+ .setSeed(42L)
+
+ val tvs2 = testDefaultReadWrite(tvs, testParams = false)
+
+ assert(tvs.getTrainRatio === tvs2.getTrainRatio)
+ }
+
+ test("read/write: TrainValidationSplitModel") {
+ val lr = new LogisticRegression()
+ .setThreshold(0.6)
+ val lrModel = new LogisticRegressionModel(lr.uid, Vectors.dense(1.0, 2.0), 1.2)
+ .setThreshold(0.6)
+ val evaluator = new BinaryClassificationEvaluator()
+ val paramMaps = new ParamGridBuilder()
+ .addGrid(lr.regParam, Array(0.1, 0.2))
+ .build()
+ val tvs = new TrainValidationSplitModel("cvUid", lrModel, Array(0.3, 0.6))
+ tvs.set(tvs.estimator, lr)
+ .set(tvs.evaluator, evaluator)
+ .set(tvs.trainRatio, 0.5)
+ .set(tvs.estimatorParamMaps, paramMaps)
+ .set(tvs.seed, 42L)
+
+ val tvs2 = testDefaultReadWrite(tvs, testParams = false)
+
+ assert(tvs.getTrainRatio === tvs2.getTrainRatio)
+ assert(tvs.validationMetrics === tvs2.validationMetrics)
+ }
}
object TrainValidationSplitSuite {
@@ -113,7 +158,7 @@ object TrainValidationSplitSuite {
class MyEstimator(override val uid: String) extends Estimator[MyModel] with HasInputCol {
- override def fit(dataset: DataFrame): MyModel = {
+ override def fit(dataset: Dataset[_]): MyModel = {
throw new UnsupportedOperationException
}
@@ -127,7 +172,7 @@ object TrainValidationSplitSuite {
class MyEvaluator extends Evaluator {
- override def evaluate(dataset: DataFrame): Double = {
+ override def evaluate(dataset: Dataset[_]): Double = {
throw new UnsupportedOperationException
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala b/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala
index 16280473c6..7ebd7eb144 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala
@@ -25,7 +25,7 @@ import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.param._
import org.apache.spark.mllib.util.MLlibTestSparkContext
-import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.{DataFrame, Dataset}
trait DefaultReadWriteTest extends TempDirectory { self: Suite =>
@@ -98,7 +98,7 @@ trait DefaultReadWriteTest extends TempDirectory { self: Suite =>
def testEstimatorAndModelReadWrite[
E <: Estimator[M] with MLWritable, M <: Model[M] with MLWritable](
estimator: E,
- dataset: DataFrame,
+ dataset: Dataset[_],
testParams: Map[String, Any],
checkModelData: (M, M) => Unit): Unit = {
// Set some Params to make sure set Params are serialized.
diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala b/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala
index d290cc9b06..8108460518 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala
@@ -17,14 +17,96 @@
package org.apache.spark.ml.util
-import org.apache.spark.ml.Model
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.param.ParamMap
+import org.apache.spark.ml.tree.impl.TreeTests
+import org.apache.spark.mllib.linalg.Vectors
+import org.apache.spark.sql.{DataFrame, SQLContext}
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.types._
-object MLTestingUtils {
+object MLTestingUtils extends SparkFunSuite {
def checkCopy(model: Model[_]): Unit = {
val copied = model.copy(ParamMap.empty)
.asInstanceOf[Model[_]]
assert(copied.parent.uid == model.parent.uid)
assert(copied.parent == model.parent)
}
+
+ def checkNumericTypes[M <: Model[M], T <: Estimator[M]](
+ estimator: T,
+ isClassification: Boolean,
+ sqlContext: SQLContext)(check: (M, M) => Unit): Unit = {
+ val dfs = if (isClassification) {
+ genClassifDFWithNumericLabelCol(sqlContext)
+ } else {
+ genRegressionDFWithNumericLabelCol(sqlContext)
+ }
+ val expected = estimator.fit(dfs(DoubleType))
+ val actuals = dfs.keys.filter(_ != DoubleType).map(t => estimator.fit(dfs(t)))
+ actuals.foreach(actual => check(expected, actual))
+
+ val dfWithStringLabels = generateDFWithStringLabelCol(sqlContext)
+ val thrown = intercept[IllegalArgumentException] {
+ estimator.fit(dfWithStringLabels)
+ }
+ assert(thrown.getMessage contains
+ "Column label must be of type NumericType but was actually of type StringType")
+ }
+
+ def genClassifDFWithNumericLabelCol(
+ sqlContext: SQLContext,
+ labelColName: String = "label",
+ featuresColName: String = "features"): Map[NumericType, DataFrame] = {
+ val df = sqlContext.createDataFrame(Seq(
+ (0, Vectors.dense(0, 2, 3)),
+ (1, Vectors.dense(0, 3, 1)),
+ (0, Vectors.dense(0, 2, 2)),
+ (1, Vectors.dense(0, 3, 9)),
+ (0, Vectors.dense(0, 2, 6))
+ )).toDF(labelColName, featuresColName)
+
+ val types =
+ Seq(ShortType, LongType, IntegerType, FloatType, ByteType, DoubleType, DecimalType(10, 0))
+ types.map(t => t -> df.select(col(labelColName).cast(t), col(featuresColName)))
+ .map { case (t, d) => t -> TreeTests.setMetadata(d, 2, labelColName) }
+ .toMap
+ }
+
+ def genRegressionDFWithNumericLabelCol(
+ sqlContext: SQLContext,
+ labelColName: String = "label",
+ featuresColName: String = "features",
+ censorColName: String = "censor"): Map[NumericType, DataFrame] = {
+ val df = sqlContext.createDataFrame(Seq(
+ (0, Vectors.dense(0)),
+ (1, Vectors.dense(1)),
+ (2, Vectors.dense(2)),
+ (3, Vectors.dense(3)),
+ (4, Vectors.dense(4))
+ )).toDF(labelColName, featuresColName)
+
+ val types =
+ Seq(ShortType, LongType, IntegerType, FloatType, ByteType, DoubleType, DecimalType(10, 0))
+ types
+ .map(t => t -> df.select(col(labelColName).cast(t), col(featuresColName)))
+ .map { case (t, d) =>
+ t -> TreeTests.setMetadata(d, 0, labelColName).withColumn(censorColName, lit(0.0))
+ }
+ .toMap
+ }
+
+ def generateDFWithStringLabelCol(
+ sqlContext: SQLContext,
+ labelColName: String = "label",
+ featuresColName: String = "features",
+ censorColName: String = "censor"): DataFrame =
+ sqlContext.createDataFrame(Seq(
+ ("0", Vectors.dense(0, 2, 3), 0.0),
+ ("1", Vectors.dense(0, 3, 1), 1.0),
+ ("0", Vectors.dense(0, 2, 2), 0.0),
+ ("1", Vectors.dense(0, 3, 9), 1.0),
+ ("0", Vectors.dense(0, 2, 6), 0.0)
+ )).toDF(labelColName, featuresColName, censorColName)
}
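
checkNumericTypes above fits an estimator on the label column cast to several numeric types and checks that each fit matches the DoubleType baseline. A rough PySpark sketch of the same idea (illustrative only, assuming an existing SQLContext named sqlContext and using LogisticRegression as a stand-in estimator):

    from pyspark.ml.classification import LogisticRegression
    from pyspark.mllib.linalg import Vectors
    from pyspark.sql.functions import col
    from pyspark.sql.types import (ByteType, DoubleType, FloatType, IntegerType,
                                   LongType, ShortType)

    # Tiny labelled dataset; the label column starts out as an integer column.
    df = sqlContext.createDataFrame([
        (0, Vectors.dense(0.0, 2.0, 3.0)),
        (1, Vectors.dense(0.0, 3.0, 1.0)),
        (0, Vectors.dense(0.0, 2.0, 2.0)),
        (1, Vectors.dense(0.0, 3.0, 9.0)),
        (0, Vectors.dense(0.0, 2.0, 6.0))], ["label", "features"])

    estimator = LogisticRegression(maxIter=5)
    expected = estimator.fit(
        df.select(col("label").cast(DoubleType()).alias("label"), col("features")))

    # Fitting on the label column cast to other numeric types should give the
    # same model as fitting on DoubleType labels (compared here via the intercept).
    for t in [ShortType(), LongType(), IntegerType(), FloatType(), ByteType()]:
        casted = df.select(col("label").cast(t).alias("label"), col("features"))
        actual = estimator.fit(casted)
        assert actual.intercept == expected.intercept
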
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/BisectingKMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/BisectingKMeansSuite.scala
index 41b9d5c0d9..35f7932ae8 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/BisectingKMeansSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/BisectingKMeansSuite.scala
@@ -21,6 +21,7 @@ import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
+import org.apache.spark.util.Utils
class BisectingKMeansSuite extends SparkFunSuite with MLlibTestSparkContext {
@@ -179,4 +180,21 @@ class BisectingKMeansSuite extends SparkFunSuite with MLlibTestSparkContext {
}
}
}
+
+ test("BisectingKMeans model save/load") {
+ val tempDir = Utils.createTempDir()
+ val path = tempDir.toURI.toString
+
+ val points = (1 until 8).map(i => Vectors.dense(i))
+ val data = sc.parallelize(points, 2)
+ val model = new BisectingKMeans().run(data)
+ try {
+ model.save(sc, path)
+ val sameModel = BisectingKMeansModel.load(sc, path)
+ assert(model.k === sameModel.k)
+ model.clusterCenters.zip(sameModel.clusterCenters).foreach(c => assert(c._1 === c._2))
+ } finally {
+ Utils.deleteRecursively(tempDir)
+ }
+ }
}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/HashingTFSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/HashingTFSuite.scala
index cf279c0233..6c07e3a5ce 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/feature/HashingTFSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/HashingTFSuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark.mllib.feature
import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.mllib.util.TestingUtils._
class HashingTFSuite extends SparkFunSuite with MLlibTestSparkContext {
@@ -48,4 +49,15 @@ class HashingTFSuite extends SparkFunSuite with MLlibTestSparkContext {
val docs = sc.parallelize(localDocs, 2)
assert(hashingTF.transform(docs).collect().toSet === localDocs.map(hashingTF.transform).toSet)
}
+
+ test("applying binary term freqs") {
+ val hashingTF = new HashingTF(100).setBinary(true)
+ val doc = "a a b c c c".split(" ")
+ val n = hashingTF.numFeatures
+ val expected = Vectors.sparse(n, Seq(
+ (hashingTF.indexOf("a"), 1.0),
+ (hashingTF.indexOf("b"), 1.0),
+ (hashingTF.indexOf("c"), 1.0)))
+ assert(hashingTF.transform(doc) ~== expected absTol 1e-14)
+ }
}
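
The new HashingTF test above checks that with binary term frequencies enabled every active index maps to 1.0 rather than a raw count. A tiny pure-Python sketch of that behaviour, with hash() standing in for hashingTF.indexOf:

    def term_freqs(doc, num_features=100, binary=False):
        """Hash each term into a bucket; with binary=True record presence (1.0) instead of counts."""
        vec = {}
        for term in doc:
            idx = hash(term) % num_features  # stand-in for hashingTF.indexOf(term)
            vec[idx] = 1.0 if binary else vec.get(idx, 0.0) + 1.0
        return vec

    doc = "a a b c c c".split(" ")
    assert max(term_freqs(doc).values()) == 3.0                 # raw counts: "c" appears 3 times
    assert set(term_freqs(doc, binary=True).values()) == {1.0}  # binary: presence only
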
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala
index a83e543859..6d8c7b47d8 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala
@@ -18,6 +18,7 @@ package org.apache.spark.mllib.fpm
import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.util.Utils
class PrefixSpanSuite extends SparkFunSuite with MLlibTestSparkContext {
@@ -357,6 +358,36 @@ class PrefixSpanSuite extends SparkFunSuite with MLlibTestSparkContext {
compareResults(expected, model.freqSequences.collect())
}
+ test("model save/load") {
+ val sequences = Seq(
+ Array(Array(1, 2), Array(3)),
+ Array(Array(1), Array(3, 2), Array(1, 2)),
+ Array(Array(1, 2), Array(5)),
+ Array(Array(6)))
+ val rdd = sc.parallelize(sequences, 2).cache()
+
+ val prefixSpan = new PrefixSpan()
+ .setMinSupport(0.5)
+ .setMaxPatternLength(5)
+ val model = prefixSpan.run(rdd)
+
+ val tempDir = Utils.createTempDir()
+ val path = tempDir.toURI.toString
+ try {
+ model.save(sc, path)
+ val newModel = PrefixSpanModel.load(sc, path)
+ val originalSet = model.freqSequences.collect().map { x =>
+ (x.sequence.map(_.toSet).toSeq, x.freq)
+ }.toSet
+ val newSet = newModel.freqSequences.collect().map { x =>
+ (x.sequence.map(_.toSet).toSeq, x.freq)
+ }.toSet
+ assert(originalSet === newSet)
+ } finally {
+ Utils.deleteRecursively(tempDir)
+ }
+ }
+
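
The assertions above compare frequent sequences after turning every itemset into a set, so that item order inside an itemset is irrelevant. The same normalization in plain Python, for illustration:

    def normalize(freq_sequences):
        # Each itemset becomes a frozenset, each sequence a tuple of frozensets,
        # and the whole collection a set of (sequence, freq) pairs.
        return {(tuple(frozenset(itemset) for itemset in seq), freq)
                for seq, freq in freq_sequences}

    original = [([[1, 2], [3]], 4), ([[5]], 2)]
    reloaded = [([[2, 1], [3]], 4), ([[5]], 2)]  # same patterns, items reordered
    assert normalize(original) == normalize(reloaded)
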
private def compareResults[Item](
expectedValue: Array[(Array[Array[Item]], Long)],
actualValue: Array[PrefixSpan.FreqSequence[Item]]): Unit = {
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala
index a02b8c9635..e289724cda 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark.mllib.linalg
import java.util.Random
+import breeze.linalg.{CSCMatrix, Matrix => BM}
import org.mockito.Mockito.when
import org.scalatest.mock.MockitoSugar._
import scala.collection.mutable.{Map => MutableMap}
@@ -150,6 +151,10 @@ class MatricesSuite extends SparkFunSuite {
sparseMat.update(0, 0, 10.0)
}
+ intercept[NoSuchElementException] {
+ sparseMat.update(2, 1, 10.0)
+ }
+
sparseMat.update(0, 1, 10.0)
assert(sparseMat(0, 1) === 10.0)
assert(sparseMat.values(2) === 10.0)
@@ -495,6 +500,17 @@ class MatricesSuite extends SparkFunSuite {
assert(sm1.numActives === 3)
}
+ test("fromBreeze with sparse matrix") {
+ // colPtr.last does NOT always equal values.length in a Breeze CSCMatrix, so an
+ // invocation of compact() may be necessary. Refer to SPARK-11507.
+ val bm1: BM[Double] = new CSCMatrix[Double](
+ Array(1.0, 1, 1), 3, 3, Array(0, 1, 2, 3), Array(0, 1, 2))
+ val bm2: BM[Double] = new CSCMatrix[Double](
+ Array(1.0, 2, 2, 4), 3, 3, Array(0, 0, 2, 4), Array(1, 2, 1, 2))
+ val sum = bm1 + bm2
+ Matrices.fromBreeze(sum)
+ }
+
test("row/col iterator") {
val dm = new DenseMatrix(3, 2, Array(0, 1, 2, 3, 4, 0))
val sm = dm.toSparse
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
index bb1041b109..49cb7e1f24 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
@@ -20,12 +20,12 @@ package org.apache.spark.mllib.tree
import scala.collection.JavaConverters._
import org.apache.spark.SparkFunSuite
+import org.apache.spark.ml.tree.impl.DecisionTreeMetadata
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.configuration.Algo._
import org.apache.spark.mllib.tree.configuration.FeatureType._
import org.apache.spark.mllib.tree.configuration.Strategy
-import org.apache.spark.mllib.tree.impl.DecisionTreeMetadata
import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Variance}
import org.apache.spark.mllib.tree.model._
import org.apache.spark.mllib.util.MLlibTestSparkContext
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala
index 747c267b4f..c61f89322d 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala
@@ -158,49 +158,6 @@ class GradientBoostedTreesSuite extends SparkFunSuite with MLlibTestSparkContext
}
}
- test("runWithValidation stops early and performs better on a validation dataset") {
- // Set numIterations large enough so that it stops early.
- val numIterations = 20
- val trainRdd = sc.parallelize(GradientBoostedTreesSuite.trainData, 2)
- val validateRdd = sc.parallelize(GradientBoostedTreesSuite.validateData, 2)
-
- val algos = Array(Regression, Regression, Classification)
- val losses = Array(SquaredError, AbsoluteError, LogLoss)
- algos.zip(losses).foreach { case (algo, loss) =>
- val treeStrategy = new Strategy(algo = algo, impurity = Variance, maxDepth = 2,
- categoricalFeaturesInfo = Map.empty)
- val boostingStrategy =
- new BoostingStrategy(treeStrategy, loss, numIterations, validationTol = 0.0)
- val gbtValidate = new GradientBoostedTrees(boostingStrategy, seed = 0)
- .runWithValidation(trainRdd, validateRdd)
- val numTrees = gbtValidate.numTrees
- assert(numTrees !== numIterations)
-
- // Test that it performs better on the validation dataset.
- val gbt = new GradientBoostedTrees(boostingStrategy, seed = 0).run(trainRdd)
- val (errorWithoutValidation, errorWithValidation) = {
- if (algo == Classification) {
- val remappedRdd = validateRdd.map(x => new LabeledPoint(2 * x.label - 1, x.features))
- (loss.computeError(gbt, remappedRdd), loss.computeError(gbtValidate, remappedRdd))
- } else {
- (loss.computeError(gbt, validateRdd), loss.computeError(gbtValidate, validateRdd))
- }
- }
- assert(errorWithValidation <= errorWithoutValidation)
-
- // Test that results from evaluateEachIteration comply with runWithValidation.
- // Note that convergenceTol is set to 0.0
- val evaluationArray = gbt.evaluateEachIteration(validateRdd, loss)
- assert(evaluationArray.length === numIterations)
- assert(evaluationArray(numTrees) > evaluationArray(numTrees - 1))
- var i = 1
- while (i < numTrees) {
- assert(evaluationArray(i) <= evaluationArray(i - 1))
- i += 1
- }
- }
- }
-
test("Checkpointing") {
val tempDir = Utils.createTempDir()
val path = tempDir.toURI.toString
@@ -220,7 +177,7 @@ class GradientBoostedTreesSuite extends SparkFunSuite with MLlibTestSparkContext
}
-private object GradientBoostedTreesSuite {
+private[spark] object GradientBoostedTreesSuite {
// Combinations for estimators, learning rates and subsamplingRate
val testCombinations = Array((10, 1.0, 1.0), (10, 0.1, 1.0), (10, 0.5, 0.75), (10, 0.1, 0.75))
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/MLlibTestSparkContext.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/MLlibTestSparkContext.scala
index ebcd591465..cb1efd5251 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/util/MLlibTestSparkContext.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/util/MLlibTestSparkContext.scala
@@ -17,14 +17,20 @@
package org.apache.spark.mllib.util
-import org.scalatest.{BeforeAndAfterAll, Suite}
+import java.io.File
+
+import org.apache.hadoop.fs.Path
+import org.scalatest.Suite
import org.apache.spark.{SparkConf, SparkContext}
+import org.apache.spark.ml.util.TempDirectory
import org.apache.spark.sql.SQLContext
+import org.apache.spark.util.Utils
-trait MLlibTestSparkContext extends BeforeAndAfterAll { self: Suite =>
+trait MLlibTestSparkContext extends TempDirectory { self: Suite =>
@transient var sc: SparkContext = _
@transient var sqlContext: SQLContext = _
+ @transient var checkpointDir: String = _
override def beforeAll() {
super.beforeAll()
@@ -35,10 +41,13 @@ trait MLlibTestSparkContext extends BeforeAndAfterAll { self: Suite =>
SQLContext.clearActive()
sqlContext = new SQLContext(sc)
SQLContext.setActive(sqlContext)
+ checkpointDir = Utils.createDirectory(tempDir.getCanonicalPath, "checkpoints").toString
+ sc.setCheckpointDir(checkpointDir)
}
override def afterAll() {
try {
+ Utils.deleteRecursively(new File(checkpointDir))
sqlContext = null
SQLContext.clearActive()
if (sc != null) {
diff --git a/pom.xml b/pom.xml
index bd82233f0c..cf17fe788c 100644
--- a/pom.xml
+++ b/pom.xml
@@ -97,6 +97,7 @@
dependencies are not 2.12-ready we need to conditionally enable those modules via the
2.11 and 2.10 build profiles -->
<!-- <module>mllib</module> -->
+ <!-- <module>mllib-local</module> -->
<module>tools</module>
<module>streaming</module>
<module>sql/catalyst</module>
@@ -134,7 +135,7 @@
<curator.version>2.4.0</curator.version>
<hive.group>org.spark-project.hive</hive.group>
<!-- Version used in Maven Hive dependency -->
- <hive.version>1.2.1.spark</hive.version>
+ <hive.version>1.2.1.spark2</hive.version>
<!-- Version used for internal directory structure -->
<hive.version.short>1.2.1</hive.version.short>
<derby.version>10.10.1.1</derby.version>
@@ -142,7 +143,7 @@
<hive.parquet.version>1.6.0</hive.parquet.version>
<jetty.version>8.1.14.v20131031</jetty.version>
<orbit.version>3.0.0.v201112011016</orbit.version>
- <chill.version>0.7.4</chill.version>
+ <chill.version>0.8.0</chill.version>
<ivy.version>2.4.0</ivy.version>
<oro.version>2.0.8</oro.version>
<codahale.metrics.version>3.1.2</codahale.metrics.version>
@@ -154,19 +155,20 @@
<aws.kinesis.producer.version>0.10.2</aws.kinesis.producer.version>
<!-- org.apache.httpcomponents/httpclient-->
<commons.httpclient.version>4.3.2</commons.httpclient.version>
+ <commons.httpcore.version>4.3.2</commons.httpcore.version>
<!-- commons-httpclient/commons-httpclient-->
<httpclient.classic.version>3.1</httpclient.classic.version>
<commons.math3.version>3.4.1</commons.math3.version>
<!-- managed up from 3.2.1 for SPARK-11652 -->
<commons.collections.version>3.2.2</commons.collections.version>
- <scala.version>2.11.7</scala.version>
+ <scala.version>2.11.8</scala.version>
<scala.binary.version>2.11</scala.binary.version>
<jline.version>${scala.version}</jline.version>
<jline.groupid>org.scala-lang</jline.groupid>
<codehaus.jackson.version>1.9.13</codehaus.jackson.version>
<fasterxml.jackson.version>2.5.3</fasterxml.jackson.version>
<fasterxml.jackson.scala.version>2.5.3</fasterxml.jackson.scala.version>
- <snappy.version>1.1.2.1</snappy.version>
+ <snappy.version>1.1.2.4</snappy.version>
<netlib.java.version>1.1.2</netlib.java.version>
<calcite.version>1.2.0-incubating</calcite.version>
<commons-codec.version>1.10</commons-codec.version>
@@ -181,15 +183,22 @@
<jodd.version>3.5.2</jodd.version>
<jsr305.version>1.3.9</jsr305.version>
<libthrift.version>0.9.2</libthrift.version>
- <antlr.version>3.5.2</antlr.version>
+ <antlr4.version>4.5.2-1</antlr4.version>
<json4s.version>3.2.2</json4s.version>
<test.java.home>${java.home}</test.java.home>
<test.exclude.tags></test.exclude.tags>
+ <!-- Package to use when relocating shaded classes. -->
+ <spark.shade.packageName>org.spark_project</spark.shade.packageName>
+
<!-- Modules that copy jars to the build directory should do so under this location. -->
<jars.target.dir>${project.build.directory}/scala-${scala.binary.version}/jars</jars.target.dir>
+ <!-- Allow modules to enable / disable certain build plugins easily. -->
+ <build.testJarPhase>prepare-package</build.testJarPhase>
+ <build.copyDependenciesPhase>none</build.copyDependenciesPhase>
+
<!--
Dependency scopes that can be overridden by enabling certain profiles. These profiles are
declared in the projects that build assemblies.
@@ -243,15 +252,6 @@
</pluginRepositories>
<dependencies>
<!--
- This is a dummy dependency that is used along with the shading plug-in
- to create effective poms on publishing (see SPARK-3812).
- -->
- <dependency>
- <groupId>org.spark-project.spark</groupId>
- <artifactId>unused</artifactId>
- <version>1.0.0</version>
- </dependency>
- <!--
This is needed by the scalatest plugin, and so is declared here to be available in
all child modules, just as scalatest is run in all children
-->
@@ -283,31 +283,11 @@
<groupId>com.twitter</groupId>
<artifactId>chill_${scala.binary.version}</artifactId>
<version>${chill.version}</version>
- <exclusions>
- <exclusion>
- <groupId>org.ow2.asm</groupId>
- <artifactId>asm</artifactId>
- </exclusion>
- <exclusion>
- <groupId>org.ow2.asm</groupId>
- <artifactId>asm-commons</artifactId>
- </exclusion>
- </exclusions>
</dependency>
<dependency>
<groupId>com.twitter</groupId>
<artifactId>chill-java</artifactId>
<version>${chill.version}</version>
- <exclusions>
- <exclusion>
- <groupId>org.ow2.asm</groupId>
- <artifactId>asm</artifactId>
- </exclusion>
- <exclusion>
- <groupId>org.ow2.asm</groupId>
- <artifactId>asm-commons</artifactId>
- </exclusion>
- </exclusions>
</dependency>
<!-- This artifact is a shaded version of ASM 5.0.4. The POM that was used to produce this
is at https://github.com/apache/geronimo-xbean/tree/xbean-4.4/xbean-asm5-shaded
@@ -419,7 +399,7 @@
<dependency>
<groupId>org.apache.httpcomponents</groupId>
<artifactId>httpcore</artifactId>
- <version>${commons.httpclient.version}</version>
+ <version>${commons.httpcore.version}</version>
</dependency>
<dependency>
<groupId>org.seleniumhq.selenium</groupId>
@@ -615,6 +595,28 @@
<scope>${hadoop.deps.scope}</scope>
</dependency>
<dependency>
+ <groupId>org.scalanlp</groupId>
+ <artifactId>breeze_${scala.binary.version}</artifactId>
+ <version>0.11.2</version>
+ <exclusions>
+ <!-- This is included as a compile-scoped dependency by jtransforms, which is
+ a dependency of breeze. -->
+ <exclusion>
+ <groupId>junit</groupId>
+ <artifactId>junit</artifactId>
+ </exclusion>
+ <exclusion>
+ <groupId>org.apache.commons</groupId>
+ <artifactId>commons-math3</artifactId>
+ </exclusion>
+ </exclusions>
+ </dependency>
+ <dependency>
+ <groupId>org.json4s</groupId>
+ <artifactId>json4s-jackson_${scala.binary.version}</artifactId>
+ <version>3.2.10</version>
+ </dependency>
+ <dependency>
<groupId>com.sun.jersey</groupId>
<artifactId>jersey-json</artifactId>
<version>${jersey.version}</version>
@@ -696,7 +698,7 @@
<groupId>com.spotify</groupId>
<artifactId>docker-client</artifactId>
<classifier>shaded</classifier>
- <version>3.4.0</version>
+ <version>3.6.6</version>
<scope>test</scope>
<exclusions>
<exclusion>
@@ -743,6 +745,10 @@
<groupId>org.jboss.netty</groupId>
<artifactId>netty</artifactId>
</exclusion>
+ <exclusion>
+ <groupId>jline</groupId>
+ <artifactId>jline</artifactId>
+ </exclusion>
</exclusions>
</dependency>
<dependency>
@@ -839,6 +845,14 @@
</exclusion>
</exclusions>
</dependency>
+ <!-- avro-mapred for some reason depends on avro-ipc's test jar, so undo that. -->
+ <dependency>
+ <groupId>org.apache.avro</groupId>
+ <artifactId>avro-ipc</artifactId>
+ <classifier>tests</classifier>
+ <version>${avro.version}</version>
+ <scope>test</scope>
+ </dependency>
<dependency>
<groupId>org.apache.avro</groupId>
<artifactId>avro-mapred</artifactId>
@@ -1527,6 +1541,10 @@
<groupId>org.codehaus.groovy</groupId>
<artifactId>groovy-all</artifactId>
</exclusion>
+ <exclusion>
+ <groupId>javax.servlet</groupId>
+ <artifactId>servlet-api</artifactId>
+ </exclusion>
</exclusions>
</dependency>
@@ -1766,8 +1784,8 @@
</dependency>
<dependency>
<groupId>org.antlr</groupId>
- <artifactId>antlr-runtime</artifactId>
- <version>${antlr.version}</version>
+ <artifactId>antlr4-runtime</artifactId>
+ <version>${antlr4.version}</version>
</dependency>
</dependencies>
</dependencyManagement>
@@ -1895,6 +1913,11 @@
<artifactId>antlr3-maven-plugin</artifactId>
<version>3.5.2</version>
</plugin>
+ <plugin>
+ <groupId>org.antlr</groupId>
+ <artifactId>antlr4-maven-plugin</artifactId>
+ <version>${antlr4.version}</version>
+ </plugin>
<!-- Surefire runs all Java tests -->
<plugin>
<groupId>org.apache.maven.plugins</groupId>
@@ -1917,10 +1940,12 @@
-->
<SPARK_DIST_CLASSPATH>${test_classpath}</SPARK_DIST_CLASSPATH>
<SPARK_PREPEND_CLASSES>1</SPARK_PREPEND_CLASSES>
+ <SPARK_SCALA_VERSION>${scala.binary.version}</SPARK_SCALA_VERSION>
<SPARK_TESTING>1</SPARK_TESTING>
<JAVA_HOME>${test.java.home}</JAVA_HOME>
</environmentVariables>
<systemProperties>
+ <log4j.configuration>file:src/test/resources/log4j.properties</log4j.configuration>
<derby.system.durability>test</derby.system.durability>
<java.awt.headless>true</java.awt.headless>
<java.io.tmpdir>${project.build.directory}/tmp</java.io.tmpdir>
@@ -1936,6 +1961,14 @@
<failIfNoTests>false</failIfNoTests>
<excludedGroups>${test.exclude.tags}</excludedGroups>
</configuration>
+ <executions>
+ <execution>
+ <id>test</id>
+ <goals>
+ <goal>test</goal>
+ </goals>
+ </execution>
+ </executions>
</plugin>
<!-- Scalatest runs all Scala tests -->
<plugin>
@@ -1956,10 +1989,12 @@
-->
<SPARK_DIST_CLASSPATH>${test_classpath}</SPARK_DIST_CLASSPATH>
<SPARK_PREPEND_CLASSES>1</SPARK_PREPEND_CLASSES>
+ <SPARK_SCALA_VERSION>${scala.binary.version}</SPARK_SCALA_VERSION>
<SPARK_TESTING>1</SPARK_TESTING>
<JAVA_HOME>${test.java.home}</JAVA_HOME>
</environmentVariables>
<systemProperties>
+ <log4j.configuration>file:src/test/resources/log4j.properties</log4j.configuration>
<derby.system.durability>test</derby.system.durability>
<java.awt.headless>true</java.awt.headless>
<java.io.tmpdir>${project.build.directory}/tmp</java.io.tmpdir>
@@ -2137,6 +2172,7 @@
<version>2.10</version>
<executions>
<execution>
+ <id>generate-test-classpath</id>
<phase>test-compile</phase>
<goals>
<goal>build-classpath</goal>
@@ -2146,6 +2182,17 @@
<outputProperty>test_classpath</outputProperty>
</configuration>
</execution>
+ <execution>
+ <id>copy-module-dependencies</id>
+ <phase>${build.copyDependenciesPhase}</phase>
+ <goals>
+ <goal>copy-dependencies</goal>
+ </goals>
+ <configuration>
+ <includeScope>runtime</includeScope>
+ <outputDirectory>${jars.target.dir}</outputDirectory>
+ </configuration>
+ </execution>
</executions>
</plugin>
@@ -2160,9 +2207,6 @@
<shadedArtifactAttached>false</shadedArtifactAttached>
<artifactSet>
<includes>
- <!-- At a minimum we must include this to force effective pom generation -->
- <include>org.spark-project.spark:unused</include>
-
<include>org.eclipse.jetty:jetty-io</include>
<include>org.eclipse.jetty:jetty-http</include>
<include>org.eclipse.jetty:jetty-continuation</include>
@@ -2177,14 +2221,14 @@
<relocations>
<relocation>
<pattern>org.eclipse.jetty</pattern>
- <shadedPattern>org.spark-project.jetty</shadedPattern>
+ <shadedPattern>${spark.shade.packageName}.jetty</shadedPattern>
<includes>
<include>org.eclipse.jetty.**</include>
</includes>
</relocation>
<relocation>
<pattern>com.google.common</pattern>
- <shadedPattern>org.spark-project.guava</shadedPattern>
+ <shadedPattern>${spark.shade.packageName}.guava</shadedPattern>
</relocation>
</relocations>
</configuration>
@@ -2242,7 +2286,7 @@
<failOnViolation>false</failOnViolation>
<includeTestSourceDirectory>true</includeTestSourceDirectory>
<failOnWarning>false</failOnWarning>
- <sourceDirectory>${basedir}/src/main/java</sourceDirectory>
+ <sourceDirectories>${basedir}/src/main/java,${basedir}/src/main/scala</sourceDirectories>
<testSourceDirectory>${basedir}/src/test/java</testSourceDirectory>
<configLocation>dev/checkstyle.xml</configLocation>
<outputFile>${basedir}/target/checkstyle-output.xml</outputFile>
@@ -2293,7 +2337,7 @@
<executions>
<execution>
<id>prepare-test-jar</id>
- <phase>prepare-package</phase>
+ <phase>${build.testJarPhase}</phase>
<goals>
<goal>test-jar</goal>
</goals>
@@ -2344,27 +2388,12 @@
<profile>
<id>java8-tests</id>
- <build>
- <plugins>
- <!-- Needed for publishing test jars as it is needed by java8-tests -->
- <plugin>
- <groupId>org.apache.maven.plugins</groupId>
- <artifactId>maven-jar-plugin</artifactId>
- <executions>
- <execution>
- <goals>
- <goal>test-jar</goal>
- </goals>
- </execution>
- </executions>
- </plugin>
- </plugins>
- </build>
-
+ <activation>
+ <jdk>[1.8,)</jdk>
+ </activation>
<modules>
<module>external/java8-tests</module>
</modules>
-
</profile>
<profile>
@@ -2454,7 +2483,7 @@
<property><name>scala-2.10</name></property>
</activation>
<properties>
- <scala.version>2.10.5</scala.version>
+ <scala.version>2.10.6</scala.version>
<scala.binary.version>2.10</scala.binary.version>
<jline.version>${scala.version}</jline.version>
<jline.groupid>org.scala-lang</jline.groupid>
@@ -2463,6 +2492,7 @@
dependencies are 2.12-ready -->
<modules>
<module>mllib</module>
+ <module>mllib-local</module>
<module>external/kafka</module>
<module>external/kafka-assembly</module>
</modules>
@@ -2494,6 +2524,7 @@
dependencies are 2.12-ready -->
<modules>
<module>mllib</module>
+ <module>mllib-local</module>
<module>external/kafka</module>
<module>external/kafka-assembly</module>
</modules>
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index 208c7a28cf..71f337ce1f 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -66,7 +66,17 @@ object MimaExcludes {
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.HadoopFsRelation$FileStatusCache")
) ++ Seq(
ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.SparkContext.emptyRDD"),
- ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.broadcast.HttpBroadcastFactory")
+ ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.broadcast.HttpBroadcastFactory"),
+ // SPARK-14358 SparkListener from trait to abstract class
+ ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.SparkContext.addSparkListener"),
+ ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.JavaSparkListener"),
+ ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.SparkFirehoseListener"),
+ ProblemFilters.exclude[IncompatibleTemplateDefProblem]("org.apache.spark.scheduler.SparkListener"),
+ ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ui.jobs.JobProgressListener"),
+ ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ui.exec.ExecutorsListener"),
+ ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ui.env.EnvironmentListener"),
+ ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ui.storage.StorageListener"),
+ ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.storage.StorageStatusListener")
) ++
Seq(
// SPARK-3369 Fix Iterable/Iterator in Java API
@@ -319,6 +329,11 @@ object MimaExcludes {
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.GroupedDataset"),
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.Dataset.subtract"),
+ // [SPARK-14451][SQL] Move encoder definition into Aggregator interface
+ ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.expressions.Aggregator.toColumn"),
+ ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.expressions.Aggregator.bufferEncoder"),
+ ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.expressions.Aggregator.outputEncoder"),
+
ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.mllib.evaluation.MultilabelMetrics.this"),
ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.classification.LogisticRegressionSummary.predictions"),
ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.ml.classification.LogisticRegressionSummary.predictions")
@@ -589,6 +604,33 @@ object MimaExcludes {
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.util.MLUtils.loadLabeledData"),
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.optimization.LBFGS.setMaxNumIterations"),
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.evaluation.BinaryClassificationEvaluator.setScoreCol")
+ ) ++ Seq(
+ // [SPARK-14205][SQL] remove trait Queryable
+ ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.sql.Dataset")
+ ) ++ Seq(
+ // [SPARK-11262][ML] Unit test for gradient, loss layers, memory management
+ // for multilayer perceptron.
+ // This class is marked as `private`.
+ ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.ml.ann.SoftmaxFunction")
+ ) ++ Seq(
+ // [SPARK-13674][SQL] Add wholestage codegen support to Sample
+ ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.util.random.PoissonSampler.this"),
+ ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.util.random.PoissonSampler.this")
+ ) ++ Seq(
+ // [SPARK-13430][ML] moved featureCol from LinearRegressionModelSummary to LinearRegressionSummary
+ ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.ml.regression.LinearRegressionSummary.this")
+ ) ++ Seq(
+ // [SPARK-14437][Core] Use the address that NettyBlockTransferService listens to create BlockManagerId
+ ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.network.netty.NettyBlockTransferService.this")
+ ) ++ Seq(
+ // [SPARK-13048][ML][MLLIB] keepLastCheckpoint option for LDA EM optimizer
+ ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.clustering.DistributedLDAModel.this")
+ ) ++ Seq(
+ // [SPARK-14475] Propagate user-defined context from driver to executors
+ ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.TaskContext.getLocalProperty"),
+ // [SPARK-14617] Remove deprecated APIs in TaskMetrics
+ ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.executor.InputMetrics$"),
+ ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.executor.OutputMetrics$")
)
case v if v.startsWith("1.6") =>
Seq(
@@ -808,7 +850,6 @@ object MimaExcludes {
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.PreInsertCastAndRename$"),
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.CreateTableUsingAsSelect$"),
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.InsertIntoDataSource$"),
- ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.SqlNewHadoopPartition"),
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.PartitioningUtils$PartitionValues$"),
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.DefaultWriterContainer"),
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.PartitioningUtils$PartitionValues"),
@@ -817,10 +858,8 @@ object MimaExcludes {
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.PartitionSpec"),
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.DynamicPartitionWriterContainer"),
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.CreateTableUsingAsSelect"),
- ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.SqlNewHadoopRDD$"),
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.DescribeCommand$"),
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.PartitioningUtils$"),
- ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.SqlNewHadoopRDD"),
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.PreInsertCastAndRename"),
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.Partition$"),
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.LogicalRelation$"),
@@ -831,7 +870,6 @@ object MimaExcludes {
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.PreWriteCheck"),
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.CreateTableUsing"),
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.RefreshTable"),
- ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.SqlNewHadoopRDD$NewHadoopMapPartitionsWithSplitRDD"),
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.DataSourceStrategy$"),
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.CreateTempTableUsing"),
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.CreateTempTableUsingAsSelect$"),
@@ -845,7 +883,6 @@ object MimaExcludes {
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.CaseInsensitiveMap"),
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.InsertIntoHadoopFsRelation$"),
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.DataSourceStrategy"),
- ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.SqlNewHadoopRDD$NewHadoopMapPartitionsWithSplitRDD$"),
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.PartitionSpec$"),
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.DescribeCommand"),
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.DDLException"),
diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala
index 6c10d740f7..cb85c63382 100644
--- a/project/SparkBuild.scala
+++ b/project/SparkBuild.scala
@@ -20,11 +20,13 @@ import java.nio.file.Files
import scala.util.Properties
import scala.collection.JavaConverters._
+import scala.collection.mutable.Stack
import sbt._
import sbt.Classpaths.publishTask
import sbt.Keys._
import sbtunidoc.Plugin.UnidocKeys.unidocGenjavadocVersion
+import com.simplytyped.Antlr4Plugin._
import com.typesafe.sbt.pom.{PomBuild, SbtPomKeys}
import com.typesafe.tools.mima.plugin.MimaKeys
@@ -45,9 +47,9 @@ object BuildCommons {
).map(ProjectRef(buildLocation, _))
val allProjects@Seq(
- core, graphx, mllib, repl, networkCommon, networkShuffle, launcher, unsafe, testTags, sketch, _*
+ core, graphx, mllib, mllibLocal, repl, networkCommon, networkShuffle, launcher, unsafe, testTags, sketch, _*
) = Seq(
- "core", "graphx", "mllib", "repl", "network-common", "network-shuffle", "launcher", "unsafe",
+ "core", "graphx", "mllib", "mllib-local", "repl", "network-common", "network-shuffle", "launcher", "unsafe",
"test-tags", "sketch"
).map(ProjectRef(buildLocation, _)) ++ sqlProjects ++ streamingProjects
@@ -56,11 +58,12 @@ object BuildCommons {
Seq("yarn", "java8-tests", "ganglia-lgpl", "streaming-kinesis-asl",
"docker-integration-tests").map(ProjectRef(buildLocation, _))
- val assemblyProjects@Seq(assembly, networkYarn, streamingFlumeAssembly, streamingKafkaAssembly, streamingKinesisAslAssembly) =
- Seq("assembly", "network-yarn", "streaming-flume-assembly", "streaming-kafka-assembly", "streaming-kinesis-asl-assembly")
+ val assemblyProjects@Seq(networkYarn, streamingFlumeAssembly, streamingKafkaAssembly, streamingKinesisAslAssembly) =
+ Seq("network-yarn", "streaming-flume-assembly", "streaming-kafka-assembly", "streaming-kinesis-asl-assembly")
.map(ProjectRef(buildLocation, _))
- val copyJarsProjects@Seq(examples) = Seq("examples").map(ProjectRef(buildLocation, _))
+ val copyJarsProjects@Seq(assembly, examples) = Seq("assembly", "examples")
+ .map(ProjectRef(buildLocation, _))
val tools = ProjectRef(buildLocation, "tools")
// Root project.
@@ -262,7 +265,7 @@ object SparkBuild extends PomBuild {
allProjects.filterNot { x =>
Seq(
spark, hive, hiveThriftServer, catalyst, repl, networkCommon, networkShuffle, networkYarn,
- unsafe, testTags, sketch
+ unsafe, testTags, sketch, mllibLocal
).contains(x)
}
} else {
@@ -276,8 +279,14 @@ object SparkBuild extends PomBuild {
/* Unsafe settings */
enable(Unsafe.settings)(unsafe)
- /* Set up tasks to copy dependencies during packaging. */
- copyJarsProjects.foreach(enable(CopyDependencies.settings))
+ /*
+ * Set up tasks to copy dependencies during packaging. This step can be disabled from the
+ * command line, so that dev/mima can run without trying to copy these files again and
+ * potentially causing issues.
+ */
+ if (!"false".equals(System.getProperty("copyDependencies"))) {
+ copyJarsProjects.foreach(enable(CopyDependencies.settings))
+ }
/* Enable Assembly for all assembly projects */
assemblyProjects.foreach(enable(Assembly.settings))
@@ -371,8 +380,10 @@ object Flume {
object DockerIntegrationTests {
// This serves to override the override specified in DependencyOverrides:
lazy val settings = Seq(
- dependencyOverrides += "com.google.guava" % "guava" % "18.0"
+ dependencyOverrides += "com.google.guava" % "guava" % "18.0",
+ resolvers ++= Seq("DB2" at "https://app.camunda.com/nexus/content/repositories/public/")
)
+
}
/**
@@ -415,57 +426,10 @@ object OldDeps {
}
object Catalyst {
- lazy val settings = Seq(
- // ANTLR code-generation step.
- //
- // This has been heavily inspired by com.github.stefri.sbt-antlr (0.5.3). It fixes a number of
- // build errors in the current plugin.
- // Create Parser from ANTLR grammar files.
- sourceGenerators in Compile += Def.task {
- val log = streams.value.log
-
- val grammarFileNames = Seq(
- "SparkSqlLexer.g",
- "SparkSqlParser.g")
- val sourceDir = (sourceDirectory in Compile).value / "antlr3"
- val targetDir = (sourceManaged in Compile).value
-
- // Create default ANTLR Tool.
- val antlr = new org.antlr.Tool
-
- // Setup input and output directories.
- antlr.setInputDirectory(sourceDir.getPath)
- antlr.setOutputDirectory(targetDir.getPath)
- antlr.setForceRelativeOutput(true)
- antlr.setMake(true)
-
- // Add grammar files.
- grammarFileNames.flatMap(gFileName => (sourceDir ** gFileName).get).foreach { gFilePath =>
- val relGFilePath = (gFilePath relativeTo sourceDir).get.getPath
- log.info("ANTLR: Grammar file '%s' detected.".format(relGFilePath))
- antlr.addGrammarFile(relGFilePath)
- // We will set library directory multiple times here. However, only the
- // last one has effect. Because the grammar files are located under the same directory,
- // We assume there is only one library directory.
- antlr.setLibDirectory(gFilePath.getParent)
- }
-
- // Generate the parser.
- antlr.process()
- val errorState = org.antlr.tool.ErrorManager.getErrorState
- if (errorState.errors > 0) {
- sys.error("ANTLR: Caught %d build errors.".format(errorState.errors))
- } else if (errorState.warnings > 0) {
- sys.error("ANTLR: Caught %d build warnings.".format(errorState.warnings))
- }
-
- // Return all generated java files.
- (targetDir ** "*.java").get.toSeq
- }.taskValue,
- // Include ANTLR tokens files.
- resourceGenerators in Compile += Def.task {
- ((sourceManaged in Compile).value ** "*.tokens").get.toSeq
- }.taskValue
+ lazy val settings = antlr4Settings ++ Seq(
+ antlr4PackageName in Antlr4 := Some("org.apache.spark.sql.catalyst.parser"),
+ antlr4GenListener in Antlr4 := true,
+ antlr4GenVisitor in Antlr4 := true
)
}
@@ -537,8 +501,6 @@ object Assembly {
val hadoopVersion = taskKey[String]("The version of hadoop that spark is compiled against.")
- val deployDatanucleusJars = taskKey[Unit]("Deploy datanucleus jars to the spark/lib_managed/jars directory")
-
lazy val settings = assemblySettings ++ Seq(
test in assembly := {},
hadoopVersion := {
@@ -557,27 +519,13 @@ object Assembly {
s"${mName}-test-${v}.jar"
},
mergeStrategy in assembly := {
- case PathList("org", "datanucleus", xs @ _*) => MergeStrategy.discard
case m if m.toLowerCase.endsWith("manifest.mf") => MergeStrategy.discard
case m if m.toLowerCase.matches("meta-inf.*\\.sf$") => MergeStrategy.discard
case "log4j.properties" => MergeStrategy.discard
case m if m.toLowerCase.startsWith("meta-inf/services/") => MergeStrategy.filterDistinctLines
case "reference.conf" => MergeStrategy.concat
case _ => MergeStrategy.first
- },
- deployDatanucleusJars := {
- val jars: Seq[File] = (fullClasspath in assembly).value.map(_.data)
- .filter(_.getPath.contains("org.datanucleus"))
- var libManagedJars = new File(BuildCommons.sparkHome, "lib_managed/jars")
- libManagedJars.mkdirs()
- jars.foreach { jar =>
- val dest = new File(libManagedJars, jar.getName)
- if (!dest.exists()) {
- Files.copy(jar.toPath, dest.toPath)
- }
- }
- },
- assembly <<= assembly.dependsOn(deployDatanucleusJars)
+ }
)
}
@@ -758,6 +706,13 @@ object Java8TestSettings {
object TestSettings {
import BuildCommons._
+ private val scalaBinaryVersion =
+ if (System.getProperty("scala-2.10") == "true") {
+ "2.10"
+ } else {
+ "2.11"
+ }
+
lazy val settings = Seq (
// Fork new JVMs for tests and set Java options for those
fork := true,
@@ -767,6 +722,7 @@ object TestSettings {
"SPARK_DIST_CLASSPATH" ->
(fullClasspath in Test).value.files.map(_.getAbsolutePath).mkString(":").stripSuffix(":"),
"SPARK_PREPEND_CLASSES" -> "1",
+ "SPARK_SCALA_VERSION" -> scalaBinaryVersion,
"SPARK_TESTING" -> "1",
"JAVA_HOME" -> sys.env.get("JAVA_HOME").getOrElse(sys.props("java.home"))),
javaOptions in Test += s"-Djava.io.tmpdir=$testTempDir",
@@ -803,8 +759,21 @@ object TestSettings {
parallelExecution in Test := false,
// Make sure the test temp directory exists.
resourceGenerators in Test <+= resourceManaged in Test map { outDir: File =>
- if (!new File(testTempDir).isDirectory()) {
- require(new File(testTempDir).mkdirs())
+ var dir = new File(testTempDir)
+ if (!dir.isDirectory()) {
+ // Because File.mkdirs() can fail if multiple callers are trying to create the same
+ // parent directory, this code tries to create parents one at a time, and avoids
+ // failures when the directories have been created by somebody else.
+ val stack = new Stack[File]()
+ while (!dir.isDirectory()) {
+ stack.push(dir)
+ dir = dir.getParentFile()
+ }
+
+ while (stack.nonEmpty) {
+ val d = stack.pop()
+ require(d.mkdir() || d.isDirectory(), s"Failed to create directory $d")
+ }
}
Seq[File]()
},
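
The directory-creation change above exists because File.mkdirs() can fail when several forked test JVMs race to create the same parent directories. A rough Python sketch of the same race-tolerant pattern (illustrative, not part of the build):

    import os

    def mkdirs_tolerant(path):
        """Create path and its missing parents one level at a time, tolerating a
        concurrent process winning the race for any individual mkdir."""
        stack = []
        d = os.path.abspath(path)
        while not os.path.isdir(d):
            stack.append(d)
            d = os.path.dirname(d)
        while stack:
            d = stack.pop()
            try:
                os.mkdir(d)
            except OSError:
                # Someone else may have created it first; only fail if it is still missing.
                if not os.path.isdir(d):
                    raise
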
diff --git a/project/plugins.sbt b/project/plugins.sbt
index eeca94a47c..44ec3a12ae 100644
--- a/project/plugins.sbt
+++ b/project/plugins.sbt
@@ -2,8 +2,6 @@ addSbtPlugin("com.eed3si9n" % "sbt-assembly" % "0.11.2")
addSbtPlugin("com.typesafe.sbteclipse" % "sbteclipse-plugin" % "4.0.0")
-addSbtPlugin("com.github.mpeltonen" % "sbt-idea" % "1.6.0")
-
addSbtPlugin("net.virtual-void" % "sbt-dependency-graph" % "0.8.2")
addSbtPlugin("org.scalastyle" %% "scalastyle-sbt-plugin" % "0.8.0")
@@ -22,4 +20,7 @@ libraryDependencies += "org.ow2.asm" % "asm" % "5.0.3"
libraryDependencies += "org.ow2.asm" % "asm-commons" % "5.0.3"
-libraryDependencies += "org.antlr" % "antlr" % "3.5.2"
+// TODO I am not sure we want such a dep.
+resolvers += "simplytyped" at "http://simplytyped.github.io/repo/releases"
+
+addSbtPlugin("com.simplytyped" % "sbt-antlr4" % "0.7.10")
diff --git a/python/docs/Makefile b/python/docs/Makefile
index 903009790b..905e0215c2 100644
--- a/python/docs/Makefile
+++ b/python/docs/Makefile
@@ -2,10 +2,10 @@
#
# You can set these variables from the command line.
-SPHINXOPTS =
-SPHINXBUILD = sphinx-build
-PAPER =
-BUILDDIR = _build
+SPHINXOPTS ?=
+SPHINXBUILD ?= sphinx-build
+PAPER ?=
+BUILDDIR ?= _build
export PYTHONPATH=$(realpath ..):$(realpath ../lib/py4j-0.9.2-src.zip)
diff --git a/python/pyspark/broadcast.py b/python/pyspark/broadcast.py
index 663c9abe08..a0b819220e 100644
--- a/python/pyspark/broadcast.py
+++ b/python/pyspark/broadcast.py
@@ -99,11 +99,26 @@ class Broadcast(object):
def unpersist(self, blocking=False):
"""
- Delete cached copies of this broadcast on the executors.
+ Delete cached copies of this broadcast on the executors. If the
+ broadcast is used after this is called, it will need to be
+ re-sent to each executor.
+
+ :param blocking: Whether to block until unpersisting has completed
"""
if self._jbroadcast is None:
raise Exception("Broadcast can only be unpersisted in driver")
self._jbroadcast.unpersist(blocking)
+
+ def destroy(self):
+ """
+ Destroy all data and metadata related to this broadcast variable.
+ Use this with caution; once a broadcast variable has been destroyed,
+ it cannot be used again. This method blocks until destroy has
+ completed.
+ """
+ if self._jbroadcast is None:
+ raise Exception("Broadcast can only be destroyed in driver")
+ self._jbroadcast.destroy()
os.unlink(self._path)
def __reduce__(self):
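
A minimal PySpark sketch contrasting the two methods documented above (assuming an existing SparkContext named sc; the broadcast value is illustrative):

    b = sc.broadcast([1, 2, 3])
    print(sc.parallelize(range(3)).map(lambda x: x + b.value[0]).collect())

    # unpersist() only drops cached copies on the executors; the broadcast may
    # still be used afterwards and will be re-sent on demand.
    b.unpersist()
    print(sc.parallelize(range(3)).map(lambda x: x + b.value[0]).collect())

    # destroy() removes all data and metadata; the variable must not be used again.
    b.destroy()
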
diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index 529d16b480..cb15b4b91f 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -428,15 +428,19 @@ class SparkContext(object):
# because it sends O(n) Py4J commands. As an alternative, serialized
# objects are written to a file and loaded through textFile().
tempFile = NamedTemporaryFile(delete=False, dir=self._temp_dir)
- # Make sure we distribute data evenly if it's smaller than self.batchSize
- if "__len__" not in dir(c):
- c = list(c) # Make it a list so we can compute its length
- batchSize = max(1, min(len(c) // numSlices, self._batchSize or 1024))
- serializer = BatchedSerializer(self._unbatched_serializer, batchSize)
- serializer.dump_stream(c, tempFile)
- tempFile.close()
- readRDDFromFile = self._jvm.PythonRDD.readRDDFromFile
- jrdd = readRDDFromFile(self._jsc, tempFile.name, numSlices)
+ try:
+ # Make sure we distribute data evenly if it's smaller than self.batchSize
+ if "__len__" not in dir(c):
+ c = list(c) # Make it a list so we can compute its length
+ batchSize = max(1, min(len(c) // numSlices, self._batchSize or 1024))
+ serializer = BatchedSerializer(self._unbatched_serializer, batchSize)
+ serializer.dump_stream(c, tempFile)
+ tempFile.close()
+ readRDDFromFile = self._jvm.PythonRDD.readRDDFromFile
+ jrdd = readRDDFromFile(self._jsc, tempFile.name, numSlices)
+ finally:
+ # readRDDFromFile eagerly reads the file so we can delete right after.
+ os.unlink(tempFile.name)
return RDD(jrdd, self, serializer)
def pickleFile(self, name, minPartitions=None):
diff --git a/python/pyspark/join.py b/python/pyspark/join.py
index 94df399016..c1f5362648 100644
--- a/python/pyspark/join.py
+++ b/python/pyspark/join.py
@@ -93,7 +93,7 @@ def python_full_outer_join(rdd, other, numPartitions):
vbuf.append(None)
if not wbuf:
wbuf.append(None)
- return [(v, w) for v in vbuf for w in wbuf]
+ return ((v, w) for v in vbuf for w in wbuf)
return _do_python_join(rdd, other, numPartitions, dispatch)
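
python_full_outer_join is the helper behind RDD.fullOuterJoin; the change above only makes the per-key dispatch lazy. A short usage sketch, assuming an existing SparkContext named sc:

    x = sc.parallelize([("a", 1), ("b", 4)])
    y = sc.parallelize([("a", 2), ("c", 8)])

    # Keys present on only one side are padded with None on the other side.
    assert sorted(x.fullOuterJoin(y).collect()) == \
        [("a", (1, 2)), ("b", (4, None)), ("c", (None, 8))]
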
diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py
index d51b80e16c..922f8069fa 100644
--- a/python/pyspark/ml/classification.py
+++ b/python/pyspark/ml/classification.py
@@ -19,15 +19,18 @@ import warnings
from pyspark import since
from pyspark.ml.util import *
-from pyspark.ml.wrapper import JavaEstimator, JavaModel
+from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaWrapper
from pyspark.ml.param import TypeConverters
from pyspark.ml.param.shared import *
from pyspark.ml.regression import (
RandomForestParams, TreeEnsembleParams, DecisionTreeModel, TreeEnsembleModels)
from pyspark.mllib.common import inherit_doc
+from pyspark.sql import DataFrame
__all__ = ['LogisticRegression', 'LogisticRegressionModel',
+ 'LogisticRegressionSummary', 'LogisticRegressionTrainingSummary',
+ 'BinaryLogisticRegressionSummary', 'BinaryLogisticRegressionTrainingSummary',
'DecisionTreeClassifier', 'DecisionTreeClassificationModel',
'GBTClassifier', 'GBTClassificationModel',
'RandomForestClassifier', 'RandomForestClassificationModel',
@@ -233,6 +236,219 @@ class LogisticRegressionModel(JavaModel, JavaMLWritable, JavaMLReadable):
"""
return self._call_java("intercept")
+ @property
+ @since("2.0.0")
+ def summary(self):
+ """
+ Gets summary (e.g. residuals, mse, r-squared) of the model on
+ the training set. An exception is thrown if
+ `trainingSummary is None`.
+ """
+ java_blrt_summary = self._call_java("summary")
+ # Note: Once multiclass is added, update this to return correct summary
+ return BinaryLogisticRegressionTrainingSummary(java_blrt_summary)
+
+ @property
+ @since("2.0.0")
+ def hasSummary(self):
+ """
+ Indicates whether a training summary exists for this model
+ instance.
+ """
+ return self._call_java("hasSummary")
+
+ @since("2.0.0")
+ def evaluate(self, dataset):
+ """
+ Evaluates the model on a test dataset.
+
+ :param dataset:
+ Test dataset to evaluate model on, where dataset is an
+ instance of :py:class:`pyspark.sql.DataFrame`
+ """
+ if not isinstance(dataset, DataFrame):
+ raise ValueError("dataset must be a DataFrame but got %s." % type(dataset))
+ java_blr_summary = self._call_java("evaluate", dataset)
+ return BinaryLogisticRegressionSummary(java_blr_summary)
+
+
+class LogisticRegressionSummary(JavaWrapper):
+ """
+ Abstraction for Logistic Regression Results for a given model.
+
+ .. versionadded:: 2.0.0
+ """
+
+ @property
+ @since("2.0.0")
+ def predictions(self):
+ """
+ DataFrame output by the model's `transform` method.
+ """
+ return self._call_java("predictions")
+
+ @property
+ @since("2.0.0")
+ def probabilityCol(self):
+ """
+ Field in "predictions" which gives the probability
+ of each class as a vector.
+ """
+ return self._call_java("probabilityCol")
+
+ @property
+ @since("2.0.0")
+ def labelCol(self):
+ """
+ Field in "predictions" which gives the true label of each
+ instance.
+ """
+ return self._call_java("labelCol")
+
+ @property
+ @since("2.0.0")
+ def featuresCol(self):
+ """
+ Field in "predictions" which gives the features of each instance
+ as a vector.
+ """
+ return self._call_java("featuresCol")
+
+
+@inherit_doc
+class LogisticRegressionTrainingSummary(LogisticRegressionSummary):
+ """
+ Abstraction for multinomial Logistic Regression Training results.
+ Currently, the training summary ignores the training weights except
+ for the objective trace.
+
+ .. versionadded:: 2.0.0
+ """
+
+ @property
+ @since("2.0.0")
+ def objectiveHistory(self):
+ """
+ Objective function (scaled loss + regularization) at each
+ iteration.
+ """
+ return self._call_java("objectiveHistory")
+
+ @property
+ @since("2.0.0")
+ def totalIterations(self):
+ """
+ Number of training iterations until termination.
+ """
+ return self._call_java("totalIterations")
+
+
+@inherit_doc
+class BinaryLogisticRegressionSummary(LogisticRegressionSummary):
+ """
+ .. note:: Experimental
+
+ Binary Logistic regression results for a given model.
+
+ .. versionadded:: 2.0.0
+ """
+
+ @property
+ @since("2.0.0")
+ def roc(self):
+ """
+ Returns the receiver operating characteristic (ROC) curve,
+ which is a DataFrame having two fields (FPR, TPR) with
+ (0.0, 0.0) prepended and (1.0, 1.0) appended to it.
+ Reference: http://en.wikipedia.org/wiki/Receiver_operating_characteristic
+
+ Note: This ignores instance weights (setting all to 1.0) from
+ `LogisticRegression.weightCol`. This will change in later Spark
+ versions.
+ """
+ return self._call_java("roc")
+
+ @property
+ @since("2.0.0")
+ def areaUnderROC(self):
+ """
+ Computes the area under the receiver operating characteristic
+ (ROC) curve.
+
+ Note: This ignores instance weights (setting all to 1.0) from
+ `LogisticRegression.weightCol`. This will change in later Spark
+ versions.
+ """
+ return self._call_java("areaUnderROC")
+
+ @property
+ @since("2.0.0")
+ def pr(self):
+ """
+ Returns the precision-recall curve, which is a DataFrame
+ containing two fields (recall, precision), with (0.0, 1.0)
+ prepended to it.
+
+ Note: This ignores instance weights (setting all to 1.0) from
+ `LogisticRegression.weightCol`. This will change in later Spark
+ versions.
+ """
+ return self._call_java("pr")
+
+ @property
+ @since("2.0.0")
+ def fMeasureByThreshold(self):
+ """
+ Returns a DataFrame with two fields (threshold, F-Measure) giving
+ the F-Measure curve with beta = 1.0.
+
+ Note: This ignores instance weights (setting all to 1.0) from
+ `LogisticRegression.weightCol`. This will change in later Spark
+ versions.
+ """
+ return self._call_java("fMeasureByThreshold")
+
+ @property
+ @since("2.0.0")
+ def precisionByThreshold(self):
+ """
+ Returns a DataFrame with two fields (threshold, precision).
+ Every possible probability obtained in transforming the dataset
+ is used as a threshold when calculating the precision.
+
+ Note: This ignores instance weights (setting all to 1.0) from
+ `LogisticRegression.weightCol`. This will change in later Spark
+ versions.
+ """
+ return self._call_java("precisionByThreshold")
+
+ @property
+ @since("2.0.0")
+ def recallByThreshold(self):
+ """
+ Returns a DataFrame with two fields (threshold, recall).
+ Every possible probability obtained from transforming the dataset
+ is used as a threshold when computing the recall.
+
+ Note: This ignores instance weights (setting all to 1.0) from
+ `LogisticRegression.weightCol`. This will change in later Spark
+ versions.
+ """
+ return self._call_java("recallByThreshold")
+
+
+@inherit_doc
+class BinaryLogisticRegressionTrainingSummary(BinaryLogisticRegressionSummary,
+ LogisticRegressionTrainingSummary):
+ """
+ .. note:: Experimental
+
+ Binary Logistic regression training results for a given model.
+
+ .. versionadded:: 2.0.0
+ """
+ pass
+
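A minimal sketch of how the summary API added above is meant to be used, assuming a live `sqlContext` as in the surrounding doctests; it mirrors the `test_logistic_regression_summary` test added to `python/pyspark/ml/tests.py` later in this patch:

    from pyspark.ml.classification import LogisticRegression
    from pyspark.mllib.linalg import Vectors

    df = sqlContext.createDataFrame(
        [(1.0, 2.0, Vectors.dense(1.0)),
         (0.0, 2.0, Vectors.sparse(1, [], []))],
        ["label", "weight", "features"])

    lr = LogisticRegression(maxIter=5, regParam=0.01, weightCol="weight", fitIntercept=False)
    model = lr.fit(df)

    if model.hasSummary:
        s = model.summary                  # BinaryLogisticRegressionTrainingSummary
        print(s.totalIterations, s.areaUnderROC)
        s.roc.show()                       # DataFrame with FPR, TPR columns

    # evaluate() recomputes the same metrics on any DataFrame with compatible columns.
    print(model.evaluate(df).areaUnderROC)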
class TreeClassifierParams(object):
"""
@@ -396,7 +612,7 @@ class DecisionTreeClassificationModel(DecisionTreeModel, JavaMLWritable, JavaMLR
- Normalize importances for tree to sum to 1.
Note: Feature importance for single decision trees can have high variance due to
- correlated predictor variables. Consider using a :class:`RandomForestClassifier`
+ correlated predictor variables. Consider using a :py:class:`RandomForestClassifier`
to determine feature importance instead.
"""
return self._call_java("featureImportances")
@@ -405,7 +621,8 @@ class DecisionTreeClassificationModel(DecisionTreeModel, JavaMLWritable, JavaMLR
@inherit_doc
class RandomForestClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasSeed,
HasRawPredictionCol, HasProbabilityCol,
- RandomForestParams, TreeClassifierParams, HasCheckpointInterval):
+ RandomForestParams, TreeClassifierParams, HasCheckpointInterval,
+ JavaMLWritable, JavaMLReadable):
"""
`http://en.wikipedia.org/wiki/Random_forest Random Forest`
learning algorithm for classification.
@@ -439,6 +656,16 @@ class RandomForestClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred
>>> test1 = sqlContext.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"])
>>> model.transform(test1).head().prediction
1.0
+ >>> rfc_path = temp_path + "/rfc"
+ >>> rf.save(rfc_path)
+ >>> rf2 = RandomForestClassifier.load(rfc_path)
+ >>> rf2.getNumTrees()
+ 3
+ >>> model_path = temp_path + "/rfc_model"
+ >>> model.save(model_path)
+ >>> model2 = RandomForestClassificationModel.load(model_path)
+ >>> model.featureImportances == model2.featureImportances
+ True
.. versionadded:: 1.4.0
"""
@@ -487,7 +714,7 @@ class RandomForestClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred
return RandomForestClassificationModel(java_model)
-class RandomForestClassificationModel(TreeEnsembleModels):
+class RandomForestClassificationModel(TreeEnsembleModels, JavaMLWritable, JavaMLReadable):
"""
Model fitted by RandomForestClassifier.
@@ -500,16 +727,12 @@ class RandomForestClassificationModel(TreeEnsembleModels):
"""
Estimate of the importance of each feature.
- This generalizes the idea of "Gini" importance to other losses,
- following the explanation of Gini importance from "Random Forests" documentation
- by Leo Breiman and Adele Cutler, and following the implementation from scikit-learn.
+ Each feature's importance is the average of its importance across all trees in the ensemble.
+ The importance vector is normalized to sum to 1. This method is suggested by Hastie et al.
+ (Hastie, Tibshirani, Friedman. "The Elements of Statistical Learning, 2nd Edition." 2001.)
+ and follows the implementation from scikit-learn.
- This feature importance is calculated as follows:
- - Average over trees:
- - importance(feature j) = sum (over nodes which split on feature j) of the gain,
- where gain is scaled by the number of instances passing through node
- - Normalize importances for tree to sum to 1.
- - Normalize feature importance vector to sum to 1.
+ .. seealso:: :py:attr:`DecisionTreeClassificationModel.featureImportances`
"""
return self._call_java("featureImportances")
@@ -534,6 +757,8 @@ class GBTClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol
>>> td = si_model.transform(df)
>>> gbt = GBTClassifier(maxIter=5, maxDepth=2, labelCol="indexed", seed=42)
>>> model = gbt.fit(td)
+ >>> model.featureImportances
+ SparseVector(1, {0: 1.0})
>>> allclose(model.treeWeights, [1.0, 0.1, 0.1, 0.1, 0.1])
True
>>> test0 = sqlContext.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
@@ -613,6 +838,21 @@ class GBTClassificationModel(TreeEnsembleModels):
.. versionadded:: 1.4.0
"""
+ @property
+ @since("2.0.0")
+ def featureImportances(self):
+ """
+ Estimate of the importance of each feature.
+
+ Each feature's importance is the average of its importance across all trees in the ensemble.
+ The importance vector is normalized to sum to 1. This method is suggested by Hastie et al.
+ (Hastie, Tibshirani, Friedman. "The Elements of Statistical Learning, 2nd Edition." 2001.)
+ and follows the implementation from scikit-learn.
+
+ .. seealso:: :py:attr:`DecisionTreeClassificationModel.featureImportances`
+ """
+ return self._call_java("featureImportances")
+
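For the ensemble models, `featureImportances` returns a single vector normalized to sum to 1, as described above. A hedged sketch (again assuming `sqlContext`), adapted from the GBTClassifier doctest in this file:

    from pyspark.ml.classification import GBTClassifier
    from pyspark.ml.feature import StringIndexer
    from pyspark.mllib.linalg import Vectors

    df = sqlContext.createDataFrame(
        [(1.0, Vectors.dense(1.0)),
         (0.0, Vectors.sparse(1, [], []))], ["label", "features"])
    # Tree classifiers need label metadata, so index the label column first.
    td = StringIndexer(inputCol="label", outputCol="indexed").fit(df).transform(df)

    model = GBTClassifier(maxIter=5, maxDepth=2, labelCol="indexed", seed=42).fit(td)
    # With a single feature, the normalized importance vector puts all mass on index 0.
    print(model.featureImportances)        # SparseVector(1, {0: 1.0})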
@inherit_doc
class NaiveBayes(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasProbabilityCol,
@@ -762,7 +1002,7 @@ class NaiveBayesModel(JavaModel, JavaMLWritable, JavaMLReadable):
@inherit_doc
class MultilayerPerceptronClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol,
- HasMaxIter, HasTol, HasSeed):
+ HasMaxIter, HasTol, HasSeed, JavaMLWritable, JavaMLReadable):
"""
Classifier trainer based on the Multilayer Perceptron.
Each layer has sigmoid activation function, output layer has softmax.
@@ -775,7 +1015,7 @@ class MultilayerPerceptronClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol,
... (1.0, Vectors.dense([0.0, 1.0])),
... (1.0, Vectors.dense([1.0, 0.0])),
... (0.0, Vectors.dense([1.0, 1.0]))], ["label", "features"])
- >>> mlp = MultilayerPerceptronClassifier(maxIter=100, layers=[2, 5, 2], blockSize=1, seed=11)
+ >>> mlp = MultilayerPerceptronClassifier(maxIter=100, layers=[2, 5, 2], blockSize=1, seed=123)
>>> model = mlp.fit(df)
>>> model.layers
[2, 5, 2]
@@ -792,6 +1032,18 @@ class MultilayerPerceptronClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol,
|[0.0,0.0]| 0.0|
+---------+----------+
...
+ >>> mlp_path = temp_path + "/mlp"
+ >>> mlp.save(mlp_path)
+ >>> mlp2 = MultilayerPerceptronClassifier.load(mlp_path)
+ >>> mlp2.getBlockSize()
+ 1
+ >>> model_path = temp_path + "/mlp_model"
+ >>> model.save(model_path)
+ >>> model2 = MultilayerPerceptronClassificationModel.load(model_path)
+ >>> model.layers == model2.layers
+ True
+ >>> model.weights == model2.weights
+ True
.. versionadded:: 1.6.0
"""
@@ -869,7 +1121,7 @@ class MultilayerPerceptronClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol,
return self.getOrDefault(self.blockSize)
-class MultilayerPerceptronClassificationModel(JavaModel):
+class MultilayerPerceptronClassificationModel(JavaModel, JavaMLWritable, JavaMLReadable):
"""
Model fitted by MultilayerPerceptronClassifier.
diff --git a/python/pyspark/ml/clustering.py b/python/pyspark/ml/clustering.py
index e22d5c8ea4..f071c597c8 100644
--- a/python/pyspark/ml/clustering.py
+++ b/python/pyspark/ml/clustering.py
@@ -171,7 +171,7 @@ class KMeans(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIter, HasTol
return self.getOrDefault(self.initSteps)
-class BisectingKMeansModel(JavaModel):
+class BisectingKMeansModel(JavaModel, JavaMLWritable, JavaMLReadable):
"""
.. note:: Experimental
@@ -195,7 +195,8 @@ class BisectingKMeansModel(JavaModel):
@inherit_doc
-class BisectingKMeans(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIter, HasSeed):
+class BisectingKMeans(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIter, HasSeed,
+ JavaMLWritable, JavaMLReadable):
"""
.. note:: Experimental
@@ -225,6 +226,18 @@ class BisectingKMeans(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIte
True
>>> rows[2].prediction == rows[3].prediction
True
+ >>> bkm_path = temp_path + "/bkm"
+ >>> bkm.save(bkm_path)
+ >>> bkm2 = BisectingKMeans.load(bkm_path)
+ >>> bkm2.getK()
+ 2
+ >>> model_path = temp_path + "/bkm_model"
+ >>> model.save(model_path)
+ >>> model2 = BisectingKMeansModel.load(model_path)
+ >>> model.clusterCenters()[0] == model2.clusterCenters()[0]
+ array([ True, True], dtype=bool)
+ >>> model.clusterCenters()[1] == model2.clusterCenters()[1]
+ array([ True, True], dtype=bool)
.. versionadded:: 2.0.0
"""
diff --git a/python/pyspark/ml/evaluation.py b/python/pyspark/ml/evaluation.py
index c9b95b3bf4..4b0bade102 100644
--- a/python/pyspark/ml/evaluation.py
+++ b/python/pyspark/ml/evaluation.py
@@ -18,7 +18,7 @@
from abc import abstractmethod, ABCMeta
from pyspark import since
-from pyspark.ml.wrapper import JavaWrapper
+from pyspark.ml.wrapper import JavaParams
from pyspark.ml.param import Param, Params
from pyspark.ml.param.shared import HasLabelCol, HasPredictionCol, HasRawPredictionCol
from pyspark.ml.util import keyword_only
@@ -81,7 +81,7 @@ class Evaluator(Params):
@inherit_doc
-class JavaEvaluator(Evaluator, JavaWrapper):
+class JavaEvaluator(JavaParams, Evaluator):
"""
Base class for :py:class:`Evaluator`s that wrap Java/Scala
implementations.
diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py
index 86b53285b5..809a513316 100644
--- a/python/pyspark/ml/feature.py
+++ b/python/pyspark/ml/feature.py
@@ -256,24 +256,33 @@ class CountVectorizer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable,
vocabSize = Param(
Params._dummy(), "vocabSize", "max size of the vocabulary. Default 1 << 18.",
typeConverter=TypeConverters.toInt)
+ binary = Param(
+ Params._dummy(), "binary", "Binary toggle to control the output vector values." +
+ " If True, all nonzero counts (after minTF filter applied) are set to 1. This is useful" +
+ " for discrete probabilistic models that model binary events rather than integer counts." +
+ " Default False", typeConverter=TypeConverters.toBoolean)
@keyword_only
- def __init__(self, minTF=1.0, minDF=1.0, vocabSize=1 << 18, inputCol=None, outputCol=None):
+ def __init__(self, minTF=1.0, minDF=1.0, vocabSize=1 << 18, binary=False, inputCol=None,
+ outputCol=None):
"""
- __init__(self, minTF=1.0, minDF=1.0, vocabSize=1 << 18, inputCol=None, outputCol=None)
+ __init__(self, minTF=1.0, minDF=1.0, vocabSize=1 << 18, binary=False, inputCol=None,\
+ outputCol=None)
"""
super(CountVectorizer, self).__init__()
self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.CountVectorizer",
self.uid)
- self._setDefault(minTF=1.0, minDF=1.0, vocabSize=1 << 18)
+ self._setDefault(minTF=1.0, minDF=1.0, vocabSize=1 << 18, binary=False)
kwargs = self.__init__._input_kwargs
self.setParams(**kwargs)
@keyword_only
@since("1.6.0")
- def setParams(self, minTF=1.0, minDF=1.0, vocabSize=1 << 18, inputCol=None, outputCol=None):
+ def setParams(self, minTF=1.0, minDF=1.0, vocabSize=1 << 18, binary=False, inputCol=None,
+ outputCol=None):
"""
- setParams(self, minTF=1.0, minDF=1.0, vocabSize=1 << 18, inputCol=None, outputCol=None)
+ setParams(self, minTF=1.0, minDF=1.0, vocabSize=1 << 18, binary=False, inputCol=None,\
+ outputCol=None)
Set the params for the CountVectorizer
"""
kwargs = self.setParams._input_kwargs
@@ -324,6 +333,21 @@ class CountVectorizer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable,
"""
return self.getOrDefault(self.vocabSize)
+ @since("2.0.0")
+ def setBinary(self, value):
+ """
+ Sets the value of :py:attr:`binary`.
+ """
+ self._paramMap[self.binary] = value
+ return self
+
+ @since("2.0.0")
+ def getBinary(self):
+ """
+ Gets the value of binary or its default value.
+ """
+ return self.getOrDefault(self.binary)
+
def _create_model(self, java_model):
return CountVectorizerModel(java_model)
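A brief sketch of the new `binary` toggle (assuming `sqlContext`); it mirrors `test_count_vectorizer_with_binary` added to the test suite in this patch:

    from pyspark.ml.feature import CountVectorizer

    df = sqlContext.createDataFrame(
        [(0, "a a a b b c".split(" ")),
         (1, "a a".split(" "))], ["id", "words"])

    cv = CountVectorizer(binary=True, inputCol="words", outputCol="features")
    model = cv.fit(df)
    # With binary=True, every nonzero term count is clamped to 1.0 in the output vectors.
    model.transform(df).select("features").show(truncate=False)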
@@ -512,14 +536,19 @@ class HashingTF(JavaTransformer, HasInputCol, HasOutputCol, HasNumFeatures, Java
.. versionadded:: 1.3.0
"""
+ binary = Param(Params._dummy(), "binary", "If True, all non zero counts are set to 1. " +
+ "This is useful for discrete probabilistic models that model binary events " +
+ "rather than integer counts. Default False.",
+ typeConverter=TypeConverters.toBoolean)
+
@keyword_only
- def __init__(self, numFeatures=1 << 18, inputCol=None, outputCol=None):
+ def __init__(self, numFeatures=1 << 18, binary=False, inputCol=None, outputCol=None):
"""
__init__(self, numFeatures=1 << 18, inputCol=None, outputCol=None)
"""
super(HashingTF, self).__init__()
self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.HashingTF", self.uid)
- self._setDefault(numFeatures=1 << 18)
+ self._setDefault(numFeatures=1 << 18, binary=False)
kwargs = self.__init__._input_kwargs
self.setParams(**kwargs)
@@ -533,6 +562,21 @@ class HashingTF(JavaTransformer, HasInputCol, HasOutputCol, HasNumFeatures, Java
kwargs = self.setParams._input_kwargs
return self._set(**kwargs)
+ @since("2.0.0")
+ def setBinary(self, value):
+ """
+ Sets the value of :py:attr:`binary`.
+ """
+ self._paramMap[self.binary] = value
+ return self
+
+ @since("2.0.0")
+ def getBinary(self):
+ """
+ Gets the value of binary or its default value.
+ """
+ return self.getOrDefault(self.binary)
+
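The matching toggle on HashingTF, sketched under the same assumptions (it mirrors `test_apply_binary_term_freqs` in the test changes below):

    from pyspark.ml.feature import HashingTF

    df = sqlContext.createDataFrame([(0, ["a", "a", "b", "c", "c", "c"])], ["id", "words"])

    hashingTF = HashingTF(numFeatures=100, inputCol="words", outputCol="features")
    hashingTF.setBinary(True)              # every nonzero raw count becomes 1.0
    hashingTF.transform(df).select("features").show(truncate=False)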
@inherit_doc
class IDF(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable):
diff --git a/python/pyspark/ml/param/_shared_params_code_gen.py b/python/pyspark/ml/param/_shared_params_code_gen.py
index 715fa9e9f8..a7615c43be 100644
--- a/python/pyspark/ml/param/_shared_params_code_gen.py
+++ b/python/pyspark/ml/param/_shared_params_code_gen.py
@@ -146,7 +146,9 @@ if __name__ == "__main__":
("weightCol", "weight column name. If this is not set or empty, we treat " +
"all instance weights as 1.0.", None, "TypeConverters.toString"),
("solver", "the solver algorithm for optimization. If this is not set or empty, " +
- "default value is 'auto'.", "'auto'", "TypeConverters.toString")]
+ "default value is 'auto'.", "'auto'", "TypeConverters.toString"),
+ ("varianceCol", "column name for the biased sample variance of prediction.",
+ None, "TypeConverters.toString")]
code = []
for name, doc, defaultValueStr, typeConverter in shared:
diff --git a/python/pyspark/ml/param/shared.py b/python/pyspark/ml/param/shared.py
index d79d55e463..c9e975525c 100644
--- a/python/pyspark/ml/param/shared.py
+++ b/python/pyspark/ml/param/shared.py
@@ -559,6 +559,30 @@ class HasSolver(Params):
return self.getOrDefault(self.solver)
+class HasVarianceCol(Params):
+ """
+ Mixin for param varianceCol: column name for the biased sample variance of prediction.
+ """
+
+ varianceCol = Param(Params._dummy(), "varianceCol", "column name for the biased sample variance of prediction.", typeConverter=TypeConverters.toString)
+
+ def __init__(self):
+ super(HasVarianceCol, self).__init__()
+
+ def setVarianceCol(self, value):
+ """
+ Sets the value of :py:attr:`varianceCol`.
+ """
+ self._set(varianceCol=value)
+ return self
+
+ def getVarianceCol(self):
+ """
+ Gets the value of varianceCol or its default value.
+ """
+ return self.getOrDefault(self.varianceCol)
+
+
class DecisionTreeParams(Params):
"""
Mixin for Decision Tree parameters.
diff --git a/python/pyspark/ml/pipeline.py b/python/pyspark/ml/pipeline.py
index 2b5504bc29..9d654e8b0f 100644
--- a/python/pyspark/ml/pipeline.py
+++ b/python/pyspark/ml/pipeline.py
@@ -25,7 +25,7 @@ from pyspark import since
from pyspark.ml import Estimator, Model, Transformer
from pyspark.ml.param import Param, Params
from pyspark.ml.util import keyword_only, JavaMLWriter, JavaMLReader, MLReadable, MLWritable
-from pyspark.ml.wrapper import JavaWrapper
+from pyspark.ml.wrapper import JavaParams
from pyspark.mllib.common import inherit_doc
@@ -177,7 +177,7 @@ class Pipeline(Estimator, MLReadable, MLWritable):
# Create a new instance of this stage.
py_stage = cls()
# Load information from java_stage to the instance.
- py_stages = [JavaWrapper._from_java(s) for s in java_stage.getStages()]
+ py_stages = [JavaParams._from_java(s) for s in java_stage.getStages()]
py_stage.setStages(py_stages)
py_stage._resetUid(java_stage.uid())
return py_stage
@@ -195,7 +195,7 @@ class Pipeline(Estimator, MLReadable, MLWritable):
for idx, stage in enumerate(self.getStages()):
java_stages[idx] = stage._to_java()
- _java_obj = JavaWrapper._new_java_obj("org.apache.spark.ml.Pipeline", self.uid)
+ _java_obj = JavaParams._new_java_obj("org.apache.spark.ml.Pipeline", self.uid)
_java_obj.setStages(java_stages)
return _java_obj
@@ -275,7 +275,7 @@ class PipelineModel(Model, MLReadable, MLWritable):
Used for ML persistence.
"""
# Load information from java_stage to the instance.
- py_stages = [JavaWrapper._from_java(s) for s in java_stage.stages()]
+ py_stages = [JavaParams._from_java(s) for s in java_stage.stages()]
# Create a new instance of this stage.
py_stage = cls(py_stages)
py_stage._resetUid(java_stage.uid())
@@ -295,6 +295,6 @@ class PipelineModel(Model, MLReadable, MLWritable):
java_stages[idx] = stage._to_java()
_java_obj =\
- JavaWrapper._new_java_obj("org.apache.spark.ml.PipelineModel", self.uid, java_stages)
+ JavaParams._new_java_obj("org.apache.spark.ml.PipelineModel", self.uid, java_stages)
return _java_obj
diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py
index 37648549de..c064fe500c 100644
--- a/python/pyspark/ml/regression.py
+++ b/python/pyspark/ml/regression.py
@@ -20,15 +20,18 @@ import warnings
from pyspark import since
from pyspark.ml.param.shared import *
from pyspark.ml.util import *
-from pyspark.ml.wrapper import JavaEstimator, JavaModel
+from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaWrapper
from pyspark.mllib.common import inherit_doc
+from pyspark.sql import DataFrame
__all__ = ['AFTSurvivalRegression', 'AFTSurvivalRegressionModel',
'DecisionTreeRegressor', 'DecisionTreeRegressionModel',
'GBTRegressor', 'GBTRegressionModel',
+ 'GeneralizedLinearRegression', 'GeneralizedLinearRegressionModel',
'IsotonicRegression', 'IsotonicRegressionModel',
'LinearRegression', 'LinearRegressionModel',
+ 'LinearRegressionSummary', 'LinearRegressionTrainingSummary',
'RandomForestRegressor', 'RandomForestRegressionModel']
@@ -131,7 +134,6 @@ class LinearRegressionModel(JavaModel, JavaMLWritable, JavaMLReadable):
"""
Model weights.
"""
-
warnings.warn("weights is deprecated. Use coefficients instead.")
return self._call_java("weights")
@@ -151,6 +153,255 @@ class LinearRegressionModel(JavaModel, JavaMLWritable, JavaMLReadable):
"""
return self._call_java("intercept")
+ @property
+ @since("2.0.0")
+ def summary(self):
+ """
+ Gets the summary (e.g. residuals, MSE, R-squared) of the model
+ on the training set. An exception is thrown if
+ `trainingSummary is None`.
+ """
+ java_lrt_summary = self._call_java("summary")
+ return LinearRegressionTrainingSummary(java_lrt_summary)
+
+ @property
+ @since("2.0.0")
+ def hasSummary(self):
+ """
+ Indicates whether a training summary exists for this model
+ instance.
+ """
+ return self._call_java("hasSummary")
+
+ @since("2.0.0")
+ def evaluate(self, dataset):
+ """
+ Evaluates the model on a test dataset.
+
+ :param dataset:
+ Test dataset to evaluate model on, where dataset is an
+ instance of :py:class:`pyspark.sql.DataFrame`
+ """
+ if not isinstance(dataset, DataFrame):
+ raise ValueError("dataset must be a DataFrame but got %s." % type(dataset))
+ java_lr_summary = self._call_java("evaluate", dataset)
+ return LinearRegressionSummary(java_lr_summary)
+
+
+class LinearRegressionSummary(JavaWrapper):
+ """
+ .. note:: Experimental
+
+ Linear regression results evaluated on a dataset.
+
+ .. versionadded:: 2.0.0
+ """
+
+ @property
+ @since("2.0.0")
+ def predictions(self):
+ """
+ DataFrame produced by the model's `transform` method.
+ """
+ return self._call_java("predictions")
+
+ @property
+ @since("2.0.0")
+ def predictionCol(self):
+ """
+ Field in "predictions" which gives the predicted value of
+ the label at each instance.
+ """
+ return self._call_java("predictionCol")
+
+ @property
+ @since("2.0.0")
+ def labelCol(self):
+ """
+ Field in "predictions" which gives the true label of each
+ instance.
+ """
+ return self._call_java("labelCol")
+
+ @property
+ @since("2.0.0")
+ def featuresCol(self):
+ """
+ Field in "predictions" which gives the features of each instance
+ as a vector.
+ """
+ return self._call_java("featuresCol")
+
+ @property
+ @since("2.0.0")
+ def explainedVariance(self):
+ """
+ Returns the explained variance regression score.
+ explainedVariance = 1 - variance(y - \hat{y}) / variance(y)
+ Reference: http://en.wikipedia.org/wiki/Explained_variation
+
+ Note: This ignores instance weights (setting all to 1.0) from
+ `LinearRegression.weightCol`. This will change in later Spark
+ versions.
+ """
+ return self._call_java("explainedVariance")
+
+ @property
+ @since("2.0.0")
+ def meanAbsoluteError(self):
+ """
+ Returns the mean absolute error, which is a risk function
+ corresponding to the expected value of the absolute error
+ loss or l1-norm loss.
+
+ Note: This ignores instance weights (setting all to 1.0) from
+ `LinearRegression.weightCol`. This will change in later Spark
+ versions.
+ """
+ return self._call_java("meanAbsoluteError")
+
+ @property
+ @since("2.0.0")
+ def meanSquaredError(self):
+ """
+ Returns the mean squared error, which is a risk function
+ corresponding to the expected value of the squared error
+ loss or quadratic loss.
+
+ Note: This ignores instance weights (setting all to 1.0) from
+ `LinearRegression.weightCol`. This will change in later Spark
+ versions.
+ """
+ return self._call_java("meanSquaredError")
+
+ @property
+ @since("2.0.0")
+ def rootMeanSquaredError(self):
+ """
+ Returns the root mean squared error, which is defined as the
+ square root of the mean squared error.
+
+ Note: This ignores instance weights (setting all to 1.0) from
+ `LinearRegression.weightCol`. This will change in later Spark
+ versions.
+ """
+ return self._call_java("rootMeanSquaredError")
+
+ @property
+ @since("2.0.0")
+ def r2(self):
+ """
+ Returns R^2, the coefficient of determination.
+ Reference: http://en.wikipedia.org/wiki/Coefficient_of_determination
+
+ Note: This ignores instance weights (setting all to 1.0) from
+ `LinearRegression.weightCol`. This will change in later Spark
+ versions.
+ """
+ return self._call_java("r2")
+
+ @property
+ @since("2.0.0")
+ def residuals(self):
+ """
+ Residuals (label - predicted value)
+ """
+ return self._call_java("residuals")
+
+ @property
+ @since("2.0.0")
+ def numInstances(self):
+ """
+ Number of instances in the predictions DataFrame.
+ """
+ return self._call_java("numInstances")
+
+ @property
+ @since("2.0.0")
+ def devianceResiduals(self):
+ """
+ The weighted residuals, the usual residuals rescaled by the
+ square root of the instance weights.
+ """
+ return self._call_java("devianceResiduals")
+
+ @property
+ @since("2.0.0")
+ def coefficientStandardErrors(self):
+ """
+ Standard error of estimated coefficients and intercept.
+ This value is only available when using the "normal" solver.
+
+ If :py:attr:`LinearRegression.fitIntercept` is set to True,
+ then the last element returned corresponds to the intercept.
+
+ .. seealso:: :py:attr:`LinearRegression.solver`
+ """
+ return self._call_java("coefficientStandardErrors")
+
+ @property
+ @since("2.0.0")
+ def tValues(self):
+ """
+ T-statistic of estimated coefficients and intercept.
+ This value is only available when using the "normal" solver.
+
+ If :py:attr:`LinearRegression.fitIntercept` is set to True,
+ then the last element returned corresponds to the intercept.
+
+ .. seealso:: :py:attr:`LinearRegression.solver`
+ """
+ return self._call_java("tValues")
+
+ @property
+ @since("2.0.0")
+ def pValues(self):
+ """
+ Two-sided p-value of estimated coefficients and intercept.
+ This value is only available when using the "normal" solver.
+
+ If :py:attr:`LinearRegression.fitIntercept` is set to True,
+ then the last element returned corresponds to the intercept.
+
+ .. seealso:: :py:attr:`LinearRegression.solver`
+ """
+ return self._call_java("pValues")
+
+
+@inherit_doc
+class LinearRegressionTrainingSummary(LinearRegressionSummary):
+ """
+ .. note:: Experimental
+
+ Linear regression training results. Currently, the training summary ignores the
+ training weights except for the objective trace.
+
+ .. versionadded:: 2.0.0
+ """
+
+ @property
+ @since("2.0.0")
+ def objectiveHistory(self):
+ """
+ Objective function (scaled loss + regularization) at each
+ iteration.
+ This value is only available when using the "l-bfgs" solver.
+
+ .. seealso:: :py:attr:`LinearRegression.solver`
+ """
+ return self._call_java("objectiveHistory")
+
+ @property
+ @since("2.0.0")
+ def totalIterations(self):
+ """
+ Number of training iterations until termination.
+ This value is only available when using the "l-bfgs" solver.
+
+ .. seealso:: :py:attr:`LinearRegression.solver`
+ """
+ return self._call_java("totalIterations")
+
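A minimal sketch of the LinearRegression summary API (assuming `sqlContext`); it mirrors `test_linear_regression_summary` added to the test suite in this patch:

    from pyspark.ml.regression import LinearRegression
    from pyspark.mllib.linalg import Vectors

    df = sqlContext.createDataFrame(
        [(1.0, 2.0, Vectors.dense(1.0)),
         (0.0, 2.0, Vectors.sparse(1, [], []))],
        ["label", "weight", "features"])

    lr = LinearRegression(maxIter=5, regParam=0.0, solver="normal",
                          weightCol="weight", fitIntercept=False)
    model = lr.fit(df)

    s = model.summary                          # LinearRegressionTrainingSummary
    print(s.rootMeanSquaredError, s.r2, s.numInstances)
    print(s.coefficientStandardErrors, s.pValues)   # available with the "normal" solver
    s.residuals.show()

    # evaluate() returns a LinearRegressionSummary for any compatible DataFrame.
    print(model.evaluate(df).meanSquaredError)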
@inherit_doc
class IsotonicRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol,
@@ -389,7 +640,7 @@ class GBTParams(TreeEnsembleParams):
@inherit_doc
class DecisionTreeRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol,
DecisionTreeParams, TreeRegressorParams, HasCheckpointInterval,
- HasSeed, JavaMLWritable, JavaMLReadable):
+ HasSeed, JavaMLWritable, JavaMLReadable, HasVarianceCol):
"""
`http://en.wikipedia.org/wiki/Decision_tree_learning Decision tree`
learning algorithm for regression.
@@ -399,7 +650,7 @@ class DecisionTreeRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi
>>> df = sqlContext.createDataFrame([
... (1.0, Vectors.dense(1.0)),
... (0.0, Vectors.sparse(1, [], []))], ["label", "features"])
- >>> dt = DecisionTreeRegressor(maxDepth=2)
+ >>> dt = DecisionTreeRegressor(maxDepth=2, varianceCol="variance")
>>> model = dt.fit(df)
>>> model.depth
1
@@ -425,6 +676,8 @@ class DecisionTreeRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi
True
>>> model.depth == model2.depth
True
+ >>> model.transform(test1).head().variance
+ 0.0
.. versionadded:: 1.4.0
"""
@@ -433,12 +686,12 @@ class DecisionTreeRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi
def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="variance",
- seed=None):
+ seed=None, varianceCol=None):
"""
__init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, \
- impurity="variance", seed=None)
+ impurity="variance", seed=None, varianceCol=None)
"""
super(DecisionTreeRegressor, self).__init__()
self._java_obj = self._new_java_obj(
@@ -454,12 +707,12 @@ class DecisionTreeRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi
def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction",
maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10,
- impurity="variance", seed=None):
+ impurity="variance", seed=None, varianceCol=None):
"""
setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, \
- impurity="variance", seed=None)
+ impurity="variance", seed=None, varianceCol=None)
Sets params for the DecisionTreeRegressor.
"""
kwargs = self.setParams._input_kwargs
@@ -533,7 +786,7 @@ class DecisionTreeRegressionModel(DecisionTreeModel, JavaMLWritable, JavaMLReada
- Normalize importances for tree to sum to 1.
Note: Feature importance for single decision trees can have high variance due to
- correlated predictor variables. Consider using a :class:`RandomForestRegressor`
+ correlated predictor variables. Consider using a :py:class:`RandomForestRegressor`
to determine feature importance instead.
"""
return self._call_java("featureImportances")
@@ -541,7 +794,8 @@ class DecisionTreeRegressionModel(DecisionTreeModel, JavaMLWritable, JavaMLReada
@inherit_doc
class RandomForestRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasSeed,
- RandomForestParams, TreeRegressorParams, HasCheckpointInterval):
+ RandomForestParams, TreeRegressorParams, HasCheckpointInterval,
+ JavaMLWritable, JavaMLReadable):
"""
`http://en.wikipedia.org/wiki/Random_forest Random Forest`
learning algorithm for regression.
@@ -564,6 +818,16 @@ class RandomForestRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi
>>> test1 = sqlContext.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"])
>>> model.transform(test1).head().prediction
0.5
+ >>> rfr_path = temp_path + "/rfr"
+ >>> rf.save(rfr_path)
+ >>> rf2 = RandomForestRegressor.load(rfr_path)
+ >>> rf2.getNumTrees()
+ 2
+ >>> model_path = temp_path + "/rfr_model"
+ >>> model.save(model_path)
+ >>> model2 = RandomForestRegressionModel.load(model_path)
+ >>> model.featureImportances == model2.featureImportances
+ True
.. versionadded:: 1.4.0
"""
@@ -613,7 +877,7 @@ class RandomForestRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi
return RandomForestRegressionModel(java_model)
-class RandomForestRegressionModel(TreeEnsembleModels):
+class RandomForestRegressionModel(TreeEnsembleModels, JavaMLWritable, JavaMLReadable):
"""
Model fitted by RandomForestRegressor.
@@ -626,16 +890,12 @@ class RandomForestRegressionModel(TreeEnsembleModels):
"""
Estimate of the importance of each feature.
- This generalizes the idea of "Gini" importance to other losses,
- following the explanation of Gini importance from "Random Forests" documentation
- by Leo Breiman and Adele Cutler, and following the implementation from scikit-learn.
+ Each feature's importance is the average of its importance across all trees in the ensemble.
+ The importance vector is normalized to sum to 1. This method is suggested by Hastie et al.
+ (Hastie, Tibshirani, Friedman. "The Elements of Statistical Learning, 2nd Edition." 2001.)
+ and follows the implementation from scikit-learn.
- This feature importance is calculated as follows:
- - Average over trees:
- - importance(feature j) = sum (over nodes which split on feature j) of the gain,
- where gain is scaled by the number of instances passing through node
- - Normalize importances for tree to sum to 1.
- - Normalize feature importance vector to sum to 1.
+ .. seealso:: :py:attr:`DecisionTreeRegressionModel.featureImportances`
"""
return self._call_java("featureImportances")
@@ -655,6 +915,8 @@ class GBTRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol,
... (0.0, Vectors.sparse(1, [], []))], ["label", "features"])
>>> gbt = GBTRegressor(maxIter=5, maxDepth=2, seed=42)
>>> model = gbt.fit(df)
+ >>> model.featureImportances
+ SparseVector(1, {0: 1.0})
>>> allclose(model.treeWeights, [1.0, 0.1, 0.1, 0.1, 0.1])
True
>>> test0 = sqlContext.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
@@ -734,6 +996,21 @@ class GBTRegressionModel(TreeEnsembleModels):
.. versionadded:: 1.4.0
"""
+ @property
+ @since("2.0.0")
+ def featureImportances(self):
+ """
+ Estimate of the importance of each feature.
+
+ Each feature's importance is the average of its importance across all trees in the ensemble.
+ The importance vector is normalized to sum to 1. This method is suggested by Hastie et al.
+ (Hastie, Tibshirani, Friedman. "The Elements of Statistical Learning, 2nd Edition." 2001.)
+ and follows the implementation from scikit-learn.
+
+ .. seealso:: :py:attr:`DecisionTreeRegressionModel.featureImportances`
+ """
+ return self._call_java("featureImportances")
+
@inherit_doc
class AFTSurvivalRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol,
@@ -921,6 +1198,150 @@ class AFTSurvivalRegressionModel(JavaModel, JavaMLWritable, JavaMLReadable):
return self._call_java("predict", features)
+@inherit_doc
+class GeneralizedLinearRegression(JavaEstimator, HasLabelCol, HasFeaturesCol, HasPredictionCol,
+ HasFitIntercept, HasMaxIter, HasTol, HasRegParam, HasWeightCol,
+ HasSolver, JavaMLWritable, JavaMLReadable):
+ """
+ Generalized Linear Regression.
+
+ Fit a Generalized Linear Model specified by giving a symbolic description of the linear
+ predictor (link function) and a description of the error distribution (family). It supports
+ "gaussian", "binomial", "poisson" and "gamma" as family. Valid link functions for each family
+ is listed below. The first link function of each family is the default one.
+ - "gaussian" -> "identity", "log", "inverse"
+ - "binomial" -> "logit", "probit", "cloglog"
+ - "poisson" -> "log", "identity", "sqrt"
+ - "gamma" -> "inverse", "identity", "log"
+
+ .. seealso:: `GLM <https://en.wikipedia.org/wiki/Generalized_linear_model>`_
+
+ >>> from pyspark.mllib.linalg import Vectors
+ >>> df = sqlContext.createDataFrame([
+ ... (1.0, Vectors.dense(0.0, 0.0)),
+ ... (1.0, Vectors.dense(1.0, 2.0)),
+ ... (2.0, Vectors.dense(0.0, 0.0)),
+ ... (2.0, Vectors.dense(1.0, 1.0)),], ["label", "features"])
+ >>> glr = GeneralizedLinearRegression(family="gaussian", link="identity")
+ >>> model = glr.fit(df)
+ >>> abs(model.transform(df).head().prediction - 1.5) < 0.001
+ True
+ >>> model.coefficients
+ DenseVector([1.5..., -1.0...])
+ >>> abs(model.intercept - 1.5) < 0.001
+ True
+ >>> glr_path = temp_path + "/glr"
+ >>> glr.save(glr_path)
+ >>> glr2 = GeneralizedLinearRegression.load(glr_path)
+ >>> glr.getFamily() == glr2.getFamily()
+ True
+ >>> model_path = temp_path + "/glr_model"
+ >>> model.save(model_path)
+ >>> model2 = GeneralizedLinearRegressionModel.load(model_path)
+ >>> model.intercept == model2.intercept
+ True
+ >>> model.coefficients[0] == model2.coefficients[0]
+ True
+
+ .. versionadded:: 2.0.0
+ """
+
+ family = Param(Params._dummy(), "family", "The name of family which is a description of " +
+ "the error distribution to be used in the model. Supported options: " +
+ "gaussian(default), binomial, poisson and gamma.")
+ link = Param(Params._dummy(), "link", "The name of link function which provides the " +
+ "relationship between the linear predictor and the mean of the distribution " +
+ "function. Supported options: identity, log, inverse, logit, probit, cloglog " +
+ "and sqrt.")
+
+ @keyword_only
+ def __init__(self, labelCol="label", featuresCol="features", predictionCol="prediction",
+ family="gaussian", link=None, fitIntercept=True, maxIter=25, tol=1e-6,
+ regParam=0.0, weightCol=None, solver="irls"):
+ """
+ __init__(self, labelCol="label", featuresCol="features", predictionCol="prediction", \
+ family="gaussian", link=None, fitIntercept=True, maxIter=25, tol=1e-6, \
+ regParam=0.0, weightCol=None, solver="irls")
+ """
+ super(GeneralizedLinearRegression, self).__init__()
+ self._java_obj = self._new_java_obj(
+ "org.apache.spark.ml.regression.GeneralizedLinearRegression", self.uid)
+ self._setDefault(family="gaussian", maxIter=25, tol=1e-6, regParam=0.0, solver="irls")
+ kwargs = self.__init__._input_kwargs
+ self.setParams(**kwargs)
+
+ @keyword_only
+ @since("2.0.0")
+ def setParams(self, labelCol="label", featuresCol="features", predictionCol="prediction",
+ family="gaussian", link=None, fitIntercept=True, maxIter=25, tol=1e-6,
+ regParam=0.0, weightCol=None, solver="irls"):
+ """
+ setParams(self, labelCol="label", featuresCol="features", predictionCol="prediction", \
+ family="gaussian", link=None, fitIntercept=True, maxIter=25, tol=1e-6, \
+ regParam=0.0, weightCol=None, solver="irls")
+ Sets params for generalized linear regression.
+ """
+ kwargs = self.setParams._input_kwargs
+ return self._set(**kwargs)
+
+ def _create_model(self, java_model):
+ return GeneralizedLinearRegressionModel(java_model)
+
+ @since("2.0.0")
+ def setFamily(self, value):
+ """
+ Sets the value of :py:attr:`family`.
+ """
+ self._paramMap[self.family] = value
+ return self
+
+ @since("2.0.0")
+ def getFamily(self):
+ """
+ Gets the value of family or its default value.
+ """
+ return self.getOrDefault(self.family)
+
+ @since("2.0.0")
+ def setLink(self, value):
+ """
+ Sets the value of :py:attr:`link`.
+ """
+ self._paramMap[self.link] = value
+ return self
+
+ @since("2.0.0")
+ def getLink(self):
+ """
+ Gets the value of link or its default value.
+ """
+ return self.getOrDefault(self.link)
+
+
+class GeneralizedLinearRegressionModel(JavaModel, JavaMLWritable, JavaMLReadable):
+ """
+ Model fitted by GeneralizedLinearRegression.
+
+ .. versionadded:: 2.0.0
+ """
+
+ @property
+ @since("2.0.0")
+ def coefficients(self):
+ """
+ Model coefficients.
+ """
+ return self._call_java("coefficients")
+
+ @property
+ @since("2.0.0")
+ def intercept(self):
+ """
+ Model intercept.
+ """
+ return self._call_java("intercept")
+
+
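A hedged sketch of the new estimator (assuming `sqlContext`); it reuses the gaussian/identity setup from the doctest above, with a note on how the other families from the table are selected:

    from pyspark.ml.regression import GeneralizedLinearRegression
    from pyspark.mllib.linalg import Vectors

    df = sqlContext.createDataFrame(
        [(1.0, Vectors.dense(0.0, 0.0)),
         (1.0, Vectors.dense(1.0, 2.0)),
         (2.0, Vectors.dense(0.0, 0.0)),
         (2.0, Vectors.dense(1.0, 1.0))], ["label", "features"])

    glr = GeneralizedLinearRegression(family="gaussian", link="identity", maxIter=10)
    model = glr.fit(df)
    print(model.coefficients)       # DenseVector([1.5..., -1.0...]) per the doctest above
    print(model.intercept)

    # Other error distributions use the links tabled above, e.g.
    # GeneralizedLinearRegression(family="poisson", link="log").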
if __name__ == "__main__":
import doctest
import pyspark.ml.regression
diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py
index 224232ed7f..86c0254a2b 100644
--- a/python/pyspark/ml/tests.py
+++ b/python/pyspark/ml/tests.py
@@ -44,15 +44,16 @@ import numpy as np
from pyspark.ml import Estimator, Model, Pipeline, PipelineModel, Transformer
from pyspark.ml.classification import LogisticRegression, DecisionTreeClassifier
from pyspark.ml.clustering import KMeans
-from pyspark.ml.evaluation import RegressionEvaluator
+from pyspark.ml.evaluation import BinaryClassificationEvaluator, RegressionEvaluator
from pyspark.ml.feature import *
from pyspark.ml.param import Param, Params, TypeConverters
from pyspark.ml.param.shared import HasMaxIter, HasInputCol, HasSeed
from pyspark.ml.regression import LinearRegression, DecisionTreeRegressor
from pyspark.ml.tuning import *
from pyspark.ml.util import keyword_only
-from pyspark.ml.wrapper import JavaWrapper
-from pyspark.mllib.linalg import DenseVector, SparseVector
+from pyspark.ml.util import MLWritable, MLWriter
+from pyspark.ml.wrapper import JavaParams
+from pyspark.mllib.linalg import Vectors, DenseVector, SparseVector
from pyspark.sql import DataFrame, SQLContext, Row
from pyspark.sql.functions import rand
from pyspark.tests import ReusedPySparkTestCase as PySparkTestCase
@@ -238,6 +239,17 @@ class OtherTestParams(HasMaxIter, HasInputCol, HasSeed):
return self._set(**kwargs)
+class HasThrowableProperty(Params):
+
+ def __init__(self):
+ super(HasThrowableProperty, self).__init__()
+ self.p = Param(self, "none", "empty param")
+
+ @property
+ def test_property(self):
+ raise RuntimeError("Test property to raise error when invoked")
+
+
class ParamTests(PySparkTestCase):
def test_copy_new_parent(self):
@@ -394,6 +406,22 @@ class FeatureTests(PySparkTestCase):
transformedDF = stopWordRemover.transform(dataset)
self.assertEqual(transformedDF.head().output, ["a"])
+ def test_count_vectorizer_with_binary(self):
+ sqlContext = SQLContext(self.sc)
+ dataset = sqlContext.createDataFrame([
+ (0, "a a a b b c".split(' '), SparseVector(3, {0: 1.0, 1: 1.0, 2: 1.0}),),
+ (1, "a a".split(' '), SparseVector(3, {0: 1.0}),),
+ (2, "a b".split(' '), SparseVector(3, {0: 1.0, 1: 1.0}),),
+ (3, "c".split(' '), SparseVector(3, {2: 1.0}),)], ["id", "words", "expected"])
+ cv = CountVectorizer(binary=True, inputCol="words", outputCol="features")
+ model = cv.fit(dataset)
+
+ transformedList = model.transform(dataset).select("features", "expected").collect()
+
+ for r in transformedList:
+ feature, expected = r
+ self.assertEqual(feature, expected)
+
class HasInducedError(Params):
@@ -478,6 +506,32 @@ class CrossValidatorTests(PySparkTestCase):
"Best model should have zero induced error")
self.assertEqual(1.0, bestModelMetric, "Best model has R-squared of 1")
+ def test_save_load(self):
+ temp_path = tempfile.mkdtemp()
+ sqlContext = SQLContext(self.sc)
+ dataset = sqlContext.createDataFrame(
+ [(Vectors.dense([0.0]), 0.0),
+ (Vectors.dense([0.4]), 1.0),
+ (Vectors.dense([0.5]), 0.0),
+ (Vectors.dense([0.6]), 1.0),
+ (Vectors.dense([1.0]), 1.0)] * 10,
+ ["features", "label"])
+ lr = LogisticRegression()
+ grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build()
+ evaluator = BinaryClassificationEvaluator()
+ cv = CrossValidator(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator)
+ cvModel = cv.fit(dataset)
+ cvPath = temp_path + "/cv"
+ cv.save(cvPath)
+ loadedCV = CrossValidator.load(cvPath)
+ self.assertEqual(loadedCV.getEstimator().uid, cv.getEstimator().uid)
+ self.assertEqual(loadedCV.getEvaluator().uid, cv.getEvaluator().uid)
+ self.assertEqual(loadedCV.getEstimatorParamMaps(), cv.getEstimatorParamMaps())
+ cvModelPath = temp_path + "/cvModel"
+ cvModel.save(cvModelPath)
+ loadedModel = CrossValidatorModel.load(cvModelPath)
+ self.assertEqual(loadedModel.bestModel.uid, cvModel.bestModel.uid)
+
class TrainValidationSplitTests(PySparkTestCase):
@@ -529,6 +583,32 @@ class TrainValidationSplitTests(PySparkTestCase):
"Best model should have zero induced error")
self.assertEqual(1.0, bestModelMetric, "Best model has R-squared of 1")
+ def test_save_load(self):
+ temp_path = tempfile.mkdtemp()
+ sqlContext = SQLContext(self.sc)
+ dataset = sqlContext.createDataFrame(
+ [(Vectors.dense([0.0]), 0.0),
+ (Vectors.dense([0.4]), 1.0),
+ (Vectors.dense([0.5]), 0.0),
+ (Vectors.dense([0.6]), 1.0),
+ (Vectors.dense([1.0]), 1.0)] * 10,
+ ["features", "label"])
+ lr = LogisticRegression()
+ grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build()
+ evaluator = BinaryClassificationEvaluator()
+ tvs = TrainValidationSplit(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator)
+ tvsModel = tvs.fit(dataset)
+ tvsPath = temp_path + "/tvs"
+ tvs.save(tvsPath)
+ loadedTvs = TrainValidationSplit.load(tvsPath)
+ self.assertEqual(loadedTvs.getEstimator().uid, tvs.getEstimator().uid)
+ self.assertEqual(loadedTvs.getEvaluator().uid, tvs.getEvaluator().uid)
+ self.assertEqual(loadedTvs.getEstimatorParamMaps(), tvs.getEstimatorParamMaps())
+ tvsModelPath = temp_path + "/tvsModel"
+ tvsModel.save(tvsModelPath)
+ loadedModel = TrainValidationSplitModel.load(tvsModelPath)
+ self.assertEqual(loadedModel.bestModel.uid, tvsModel.bestModel.uid)
+
class PersistenceTest(PySparkTestCase):
@@ -580,7 +660,7 @@ class PersistenceTest(PySparkTestCase):
"""
self.assertEqual(m1.uid, m2.uid)
self.assertEqual(type(m1), type(m2))
- if isinstance(m1, JavaWrapper):
+ if isinstance(m1, JavaParams):
self.assertEqual(len(m1.params), len(m2.params))
for p in m1.params:
self.assertEqual(m1.getOrDefault(p), m2.getOrDefault(p))
@@ -655,6 +735,10 @@ class PersistenceTest(PySparkTestCase):
except OSError:
pass
+ def test_write_property(self):
+ lr = LinearRegression(maxIter=1)
+ self.assertTrue(isinstance(lr.write, MLWriter))
+
def test_decisiontree_classifier(self):
dt = DecisionTreeClassifier(maxDepth=1)
path = tempfile.mkdtemp()
@@ -692,15 +776,94 @@ class PersistenceTest(PySparkTestCase):
pass
-class HasThrowableProperty(Params):
+class TrainingSummaryTest(PySparkTestCase):
- def __init__(self):
- super(HasThrowableProperty, self).__init__()
- self.p = Param(self, "none", "empty param")
+ def test_linear_regression_summary(self):
+ from pyspark.mllib.linalg import Vectors
+ sqlContext = SQLContext(self.sc)
+ df = sqlContext.createDataFrame([(1.0, 2.0, Vectors.dense(1.0)),
+ (0.0, 2.0, Vectors.sparse(1, [], []))],
+ ["label", "weight", "features"])
+ lr = LinearRegression(maxIter=5, regParam=0.0, solver="normal", weightCol="weight",
+ fitIntercept=False)
+ model = lr.fit(df)
+ self.assertTrue(model.hasSummary)
+ s = model.summary
+ # test that api is callable and returns expected types
+ self.assertGreater(s.totalIterations, 0)
+ self.assertTrue(isinstance(s.predictions, DataFrame))
+ self.assertEqual(s.predictionCol, "prediction")
+ self.assertEqual(s.labelCol, "label")
+ self.assertEqual(s.featuresCol, "features")
+ objHist = s.objectiveHistory
+ self.assertTrue(isinstance(objHist, list) and isinstance(objHist[0], float))
+ self.assertAlmostEqual(s.explainedVariance, 0.25, 2)
+ self.assertAlmostEqual(s.meanAbsoluteError, 0.0)
+ self.assertAlmostEqual(s.meanSquaredError, 0.0)
+ self.assertAlmostEqual(s.rootMeanSquaredError, 0.0)
+ self.assertAlmostEqual(s.r2, 1.0, 2)
+ self.assertTrue(isinstance(s.residuals, DataFrame))
+ self.assertEqual(s.numInstances, 2)
+ devResiduals = s.devianceResiduals
+ self.assertTrue(isinstance(devResiduals, list) and isinstance(devResiduals[0], float))
+ coefStdErr = s.coefficientStandardErrors
+ self.assertTrue(isinstance(coefStdErr, list) and isinstance(coefStdErr[0], float))
+ tValues = s.tValues
+ self.assertTrue(isinstance(tValues, list) and isinstance(tValues[0], float))
+ pValues = s.pValues
+ self.assertTrue(isinstance(pValues, list) and isinstance(pValues[0], float))
+ # test evaluation (with training dataset) produces a summary with same values
+ # one check is enough to verify a summary is returned, Scala version runs full test
+ sameSummary = model.evaluate(df)
+ self.assertAlmostEqual(sameSummary.explainedVariance, s.explainedVariance)
+
+ def test_logistic_regression_summary(self):
+ from pyspark.mllib.linalg import Vectors
+ sqlContext = SQLContext(self.sc)
+ df = sqlContext.createDataFrame([(1.0, 2.0, Vectors.dense(1.0)),
+ (0.0, 2.0, Vectors.sparse(1, [], []))],
+ ["label", "weight", "features"])
+ lr = LogisticRegression(maxIter=5, regParam=0.01, weightCol="weight", fitIntercept=False)
+ model = lr.fit(df)
+ self.assertTrue(model.hasSummary)
+ s = model.summary
+ # test that api is callable and returns expected types
+ self.assertTrue(isinstance(s.predictions, DataFrame))
+ self.assertEqual(s.probabilityCol, "probability")
+ self.assertEqual(s.labelCol, "label")
+ self.assertEqual(s.featuresCol, "features")
+ objHist = s.objectiveHistory
+ self.assertTrue(isinstance(objHist, list) and isinstance(objHist[0], float))
+ self.assertGreater(s.totalIterations, 0)
+ self.assertTrue(isinstance(s.roc, DataFrame))
+ self.assertAlmostEqual(s.areaUnderROC, 1.0, 2)
+ self.assertTrue(isinstance(s.pr, DataFrame))
+ self.assertTrue(isinstance(s.fMeasureByThreshold, DataFrame))
+ self.assertTrue(isinstance(s.precisionByThreshold, DataFrame))
+ self.assertTrue(isinstance(s.recallByThreshold, DataFrame))
+ # test evaluation (with training dataset) produces a summary with same values
+ # one check is enough to verify a summary is returned, Scala version runs full test
+ sameSummary = model.evaluate(df)
+ self.assertAlmostEqual(sameSummary.areaUnderROC, s.areaUnderROC)
+
+
+class HashingTFTest(PySparkTestCase):
+
+ def test_apply_binary_term_freqs(self):
+ sqlContext = SQLContext(self.sc)
- @property
- def test_property(self):
- raise RuntimeError("Test property to raise error when invoked")
+ df = sqlContext.createDataFrame([(0, ["a", "a", "b", "c", "c", "c"])], ["id", "words"])
+ n = 100
+ hashingTF = HashingTF()
+ hashingTF.setInputCol("words").setOutputCol("features").setNumFeatures(n).setBinary(True)
+ output = hashingTF.transform(df)
+ features = output.select("features").first().features.toArray()
+ expected = Vectors.sparse(n, {(ord("a") % n): 1.0,
+ (ord("b") % n): 1.0,
+ (ord("c") % n): 1.0}).toArray()
+ for i in range(0, n):
+ self.assertAlmostEqual(features[i], expected[i], 14, "Error at " + str(i) +
+ ": expected " + str(expected[i]) + ", got " + str(features[i]))
if __name__ == "__main__":
diff --git a/python/pyspark/ml/tuning.py b/python/pyspark/ml/tuning.py
index a528d22e18..456d79d897 100644
--- a/python/pyspark/ml/tuning.py
+++ b/python/pyspark/ml/tuning.py
@@ -18,12 +18,15 @@
import itertools
import numpy as np
+from pyspark import SparkContext
from pyspark import since
from pyspark.ml import Estimator, Model
from pyspark.ml.param import Params, Param, TypeConverters
from pyspark.ml.param.shared import HasSeed
-from pyspark.ml.util import keyword_only
+from pyspark.ml.util import keyword_only, JavaMLWriter, JavaMLReader, MLReadable, MLWritable
+from pyspark.ml.wrapper import JavaParams
from pyspark.sql.functions import rand
+from pyspark.mllib.common import inherit_doc, _py2java
__all__ = ['ParamGridBuilder', 'CrossValidator', 'CrossValidatorModel', 'TrainValidationSplit',
'TrainValidationSplitModel']
@@ -91,7 +94,84 @@ class ParamGridBuilder(object):
return [dict(zip(keys, prod)) for prod in itertools.product(*grid_values)]
-class CrossValidator(Estimator, HasSeed):
+class ValidatorParams(HasSeed):
+ """
+ Common params for TrainValidationSplit and CrossValidator.
+ """
+
+ estimator = Param(Params._dummy(), "estimator", "estimator to be cross-validated")
+ estimatorParamMaps = Param(Params._dummy(), "estimatorParamMaps", "estimator param maps")
+ evaluator = Param(
+ Params._dummy(), "evaluator",
+ "evaluator used to select hyper-parameters that maximize the validator metric")
+
+ def setEstimator(self, value):
+ """
+ Sets the value of :py:attr:`estimator`.
+ """
+ return self._set(estimator=value)
+
+ def getEstimator(self):
+ """
+ Gets the value of estimator or its default value.
+ """
+ return self.getOrDefault(self.estimator)
+
+ def setEstimatorParamMaps(self, value):
+ """
+ Sets the value of :py:attr:`estimatorParamMaps`.
+ """
+ return self._set(estimatorParamMaps=value)
+
+ def getEstimatorParamMaps(self):
+ """
+ Gets the value of estimatorParamMaps or its default value.
+ """
+ return self.getOrDefault(self.estimatorParamMaps)
+
+ def setEvaluator(self, value):
+ """
+ Sets the value of :py:attr:`evaluator`.
+ """
+ return self._set(evaluator=value)
+
+ def getEvaluator(self):
+ """
+ Gets the value of evaluator or its default value.
+ """
+ return self.getOrDefault(self.evaluator)
+
+ @classmethod
+ def _from_java_impl(cls, java_stage):
+ """
+ Return Python estimator, estimatorParamMaps, and evaluator from a Java ValidatorParams.
+ """
+
+ # Load information from java_stage to the instance.
+ estimator = JavaParams._from_java(java_stage.getEstimator())
+ evaluator = JavaParams._from_java(java_stage.getEvaluator())
+ epms = [estimator._transfer_param_map_from_java(epm)
+ for epm in java_stage.getEstimatorParamMaps()]
+ return estimator, epms, evaluator
+
+ def _to_java_impl(self):
+ """
+ Return Java estimator, estimatorParamMaps, and evaluator from this Python instance.
+ """
+
+ gateway = SparkContext._gateway
+ cls = SparkContext._jvm.org.apache.spark.ml.param.ParamMap
+
+ java_epms = gateway.new_array(cls, len(self.getEstimatorParamMaps()))
+ for idx, epm in enumerate(self.getEstimatorParamMaps()):
+ java_epms[idx] = self.getEstimator()._transfer_param_map_to_java(epm)
+
+ java_estimator = self.getEstimator()._to_java()
+ java_evaluator = self.getEvaluator()._to_java()
+ return java_estimator, java_epms, java_evaluator
+
+
+class CrossValidator(Estimator, ValidatorParams, MLReadable, MLWritable):
"""
K-fold cross validation.
@@ -116,11 +196,6 @@ class CrossValidator(Estimator, HasSeed):
.. versionadded:: 1.4.0
"""
- estimator = Param(Params._dummy(), "estimator", "estimator to be cross-validated")
- estimatorParamMaps = Param(Params._dummy(), "estimatorParamMaps", "estimator param maps")
- evaluator = Param(
- Params._dummy(), "evaluator",
- "evaluator used to select hyper-parameters that maximize the cross-validated metric")
numFolds = Param(Params._dummy(), "numFolds", "number of folds for cross validation",
typeConverter=TypeConverters.toInt)
@@ -149,51 +224,6 @@ class CrossValidator(Estimator, HasSeed):
return self._set(**kwargs)
@since("1.4.0")
- def setEstimator(self, value):
- """
- Sets the value of :py:attr:`estimator`.
- """
- self._paramMap[self.estimator] = value
- return self
-
- @since("1.4.0")
- def getEstimator(self):
- """
- Gets the value of estimator or its default value.
- """
- return self.getOrDefault(self.estimator)
-
- @since("1.4.0")
- def setEstimatorParamMaps(self, value):
- """
- Sets the value of :py:attr:`estimatorParamMaps`.
- """
- self._paramMap[self.estimatorParamMaps] = value
- return self
-
- @since("1.4.0")
- def getEstimatorParamMaps(self):
- """
- Gets the value of estimatorParamMaps or its default value.
- """
- return self.getOrDefault(self.estimatorParamMaps)
-
- @since("1.4.0")
- def setEvaluator(self, value):
- """
- Sets the value of :py:attr:`evaluator`.
- """
- self._paramMap[self.evaluator] = value
- return self
-
- @since("1.4.0")
- def getEvaluator(self):
- """
- Gets the value of evaluator or its default value.
- """
- return self.getOrDefault(self.evaluator)
-
- @since("1.4.0")
def setNumFolds(self, value):
"""
Sets the value of :py:attr:`numFolds`.
@@ -236,7 +266,7 @@ class CrossValidator(Estimator, HasSeed):
else:
bestIndex = np.argmin(metrics)
bestModel = est.fit(dataset, epm[bestIndex])
- return CrossValidatorModel(bestModel)
+ return self._copyValues(CrossValidatorModel(bestModel))
@since("1.4.0")
def copy(self, extra=None):
@@ -258,8 +288,58 @@ class CrossValidator(Estimator, HasSeed):
newCV.setEvaluator(self.getEvaluator().copy(extra))
return newCV
+ @since("2.0.0")
+ def write(self):
+ """Returns an MLWriter instance for this ML instance."""
+ return JavaMLWriter(self)
+
+ @since("2.0.0")
+ def save(self, path):
+ """Save this ML instance to the given path, a shortcut of `write().save(path)`."""
+ self.write().save(path)
+
+ @classmethod
+ @since("2.0.0")
+ def read(cls):
+ """Returns an MLReader instance for this class."""
+ return JavaMLReader(cls)
+
+ @classmethod
+ def _from_java(cls, java_stage):
+ """
+ Given a Java CrossValidator, create and return a Python wrapper of it.
+ Used for ML persistence.
+ """
+
+ estimator, epms, evaluator = super(CrossValidator, cls)._from_java_impl(java_stage)
+ numFolds = java_stage.getNumFolds()
+ seed = java_stage.getSeed()
+ # Create a new instance of this stage.
+ py_stage = cls(estimator=estimator, estimatorParamMaps=epms, evaluator=evaluator,
+ numFolds=numFolds, seed=seed)
+ py_stage._resetUid(java_stage.uid())
+ return py_stage
+
+ def _to_java(self):
+ """
+ Transfer this instance to a Java CrossValidator. Used for ML persistence.
+
+ :return: Java object equivalent to this instance.
+ """
+
+ estimator, epms, evaluator = super(CrossValidator, self)._to_java_impl()
+
+ _java_obj = JavaParams._new_java_obj("org.apache.spark.ml.tuning.CrossValidator", self.uid)
+ _java_obj.setEstimatorParamMaps(epms)
+ _java_obj.setEvaluator(evaluator)
+ _java_obj.setEstimator(estimator)
+ _java_obj.setSeed(self.getSeed())
+ _java_obj.setNumFolds(self.getNumFolds())
+
+ return _java_obj
+
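A sketch of the CrossValidator persistence added here (assuming `sqlContext`); it mirrors `CrossValidatorTests.test_save_load` from the test changes in this patch:

    import tempfile
    from pyspark.ml.classification import LogisticRegression
    from pyspark.ml.evaluation import BinaryClassificationEvaluator
    from pyspark.ml.tuning import CrossValidator, CrossValidatorModel, ParamGridBuilder
    from pyspark.mllib.linalg import Vectors

    dataset = sqlContext.createDataFrame(
        [(Vectors.dense([0.0]), 0.0), (Vectors.dense([0.4]), 1.0),
         (Vectors.dense([0.5]), 0.0), (Vectors.dense([0.6]), 1.0),
         (Vectors.dense([1.0]), 1.0)] * 10, ["features", "label"])

    lr = LogisticRegression()
    grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build()
    cv = CrossValidator(estimator=lr, estimatorParamMaps=grid,
                        evaluator=BinaryClassificationEvaluator())
    cvModel = cv.fit(dataset)

    path = tempfile.mkdtemp()
    cv.save(path + "/cv")                       # estimator, param maps and evaluator round-trip
    cvModel.save(path + "/cvModel")
    loaded = CrossValidator.load(path + "/cv")
    loadedModel = CrossValidatorModel.load(path + "/cvModel")
    print(loaded.getEstimator().uid == cv.getEstimator().uid)
    print(loadedModel.bestModel.uid == cvModel.bestModel.uid)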
-class CrossValidatorModel(Model):
+class CrossValidatorModel(Model, ValidatorParams, MLReadable, MLWritable):
"""
Model from k-fold cross validation.
@@ -289,8 +369,60 @@ class CrossValidatorModel(Model):
extra = dict()
return CrossValidatorModel(self.bestModel.copy(extra))
+ @since("2.0.0")
+ def write(self):
+ """Returns an MLWriter instance for this ML instance."""
+ return JavaMLWriter(self)
+
+ @since("2.0.0")
+ def save(self, path):
+ """Save this ML instance to the given path, a shortcut of `write().save(path)`."""
+ self.write().save(path)
-class TrainValidationSplit(Estimator, HasSeed):
+ @classmethod
+ @since("2.0.0")
+ def read(cls):
+ """Returns an MLReader instance for this class."""
+ return JavaMLReader(cls)
+
+ @classmethod
+ def _from_java(cls, java_stage):
+ """
+ Given a Java CrossValidatorModel, create and return a Python wrapper of it.
+ Used for ML persistence.
+ """
+
+ # Load information from java_stage to the instance.
+ bestModel = JavaParams._from_java(java_stage.bestModel())
+ estimator, epms, evaluator = super(CrossValidatorModel, cls)._from_java_impl(java_stage)
+ # Create a new instance of this stage.
+ py_stage = cls(bestModel=bestModel)\
+ .setEstimator(estimator).setEstimatorParamMaps(epms).setEvaluator(evaluator)
+ py_stage._resetUid(java_stage.uid())
+ return py_stage
+
+ def _to_java(self):
+ """
+ Transfer this instance to a Java CrossValidatorModel. Used for ML persistence.
+
+ :return: Java object equivalent to this instance.
+ """
+
+ sc = SparkContext._active_spark_context
+
+ _java_obj = JavaParams._new_java_obj("org.apache.spark.ml.tuning.CrossValidatorModel",
+ self.uid,
+ self.bestModel._to_java(),
+ _py2java(sc, []))
+ estimator, epms, evaluator = super(CrossValidatorModel, self)._to_java_impl()
+
+ _java_obj.set("evaluator", evaluator)
+ _java_obj.set("estimator", estimator)
+ _java_obj.set("estimatorParamMaps", epms)
+ return _java_obj
+
+
+class TrainValidationSplit(Estimator, ValidatorParams, MLReadable, MLWritable):
"""
Train-Validation-Split.
@@ -315,11 +447,6 @@ class TrainValidationSplit(Estimator, HasSeed):
.. versionadded:: 2.0.0
"""
- estimator = Param(Params._dummy(), "estimator", "estimator to be tested")
- estimatorParamMaps = Param(Params._dummy(), "estimatorParamMaps", "estimator param maps")
- evaluator = Param(
- Params._dummy(), "evaluator",
- "evaluator used to select hyper-parameters that maximize the validated metric")
trainRatio = Param(Params._dummy(), "trainRatio", "Param for ratio between train and\
validation data. Must be between 0 and 1.")
@@ -348,51 +475,6 @@ class TrainValidationSplit(Estimator, HasSeed):
return self._set(**kwargs)
@since("2.0.0")
- def setEstimator(self, value):
- """
- Sets the value of :py:attr:`estimator`.
- """
- self._paramMap[self.estimator] = value
- return self
-
- @since("2.0.0")
- def getEstimator(self):
- """
- Gets the value of estimator or its default value.
- """
- return self.getOrDefault(self.estimator)
-
- @since("2.0.0")
- def setEstimatorParamMaps(self, value):
- """
- Sets the value of :py:attr:`estimatorParamMaps`.
- """
- self._paramMap[self.estimatorParamMaps] = value
- return self
-
- @since("2.0.0")
- def getEstimatorParamMaps(self):
- """
- Gets the value of estimatorParamMaps or its default value.
- """
- return self.getOrDefault(self.estimatorParamMaps)
-
- @since("2.0.0")
- def setEvaluator(self, value):
- """
- Sets the value of :py:attr:`evaluator`.
- """
- self._paramMap[self.evaluator] = value
- return self
-
- @since("2.0.0")
- def getEvaluator(self):
- """
- Gets the value of evaluator or its default value.
- """
- return self.getOrDefault(self.evaluator)
-
- @since("2.0.0")
def setTrainRatio(self, value):
"""
Sets the value of :py:attr:`trainRatio`.
@@ -429,7 +511,7 @@ class TrainValidationSplit(Estimator, HasSeed):
else:
bestIndex = np.argmin(metrics)
bestModel = est.fit(dataset, epm[bestIndex])
- return TrainValidationSplitModel(bestModel)
+ return self._copyValues(TrainValidationSplitModel(bestModel))
@since("2.0.0")
def copy(self, extra=None):
@@ -451,10 +533,63 @@ class TrainValidationSplit(Estimator, HasSeed):
newTVS.setEvaluator(self.getEvaluator().copy(extra))
return newTVS
+ @since("2.0.0")
+ def write(self):
+ """Returns an MLWriter instance for this ML instance."""
+ return JavaMLWriter(self)
+
+ @since("2.0.0")
+ def save(self, path):
+ """Save this ML instance to the given path, a shortcut of `write().save(path)`."""
+ self.write().save(path)
+
+ @classmethod
+ @since("2.0.0")
+ def read(cls):
+ """Returns an MLReader instance for this class."""
+ return JavaMLReader(cls)
+
+ @classmethod
+ def _from_java(cls, java_stage):
+ """
+ Given a Java TrainValidationSplit, create and return a Python wrapper of it.
+ Used for ML persistence.
+ """
+
+ estimator, epms, evaluator = super(TrainValidationSplit, cls)._from_java_impl(java_stage)
+ trainRatio = java_stage.getTrainRatio()
+ seed = java_stage.getSeed()
+ # Create a new instance of this stage.
+ py_stage = cls(estimator=estimator, estimatorParamMaps=epms, evaluator=evaluator,
+ trainRatio=trainRatio, seed=seed)
+ py_stage._resetUid(java_stage.uid())
+ return py_stage
+
+ def _to_java(self):
+ """
+ Transfer this instance to a Java TrainValidationSplit. Used for ML persistence.
+
+ :return: Java object equivalent to this instance.
+ """
+
+ estimator, epms, evaluator = super(TrainValidationSplit, self)._to_java_impl()
+
+ _java_obj = JavaParams._new_java_obj("org.apache.spark.ml.tuning.TrainValidationSplit",
+ self.uid)
+ _java_obj.setEstimatorParamMaps(epms)
+ _java_obj.setEvaluator(evaluator)
+ _java_obj.setEstimator(estimator)
+ _java_obj.setTrainRatio(self.getTrainRatio())
+ _java_obj.setSeed(self.getSeed())
+
+ return _java_obj
-class TrainValidationSplitModel(Model):
+
+class TrainValidationSplitModel(Model, ValidatorParams, MLReadable, MLWritable):
"""
Model from train validation split.
+
+ .. versionadded:: 2.0.0
"""
def __init__(self, bestModel):
@@ -480,19 +615,75 @@ class TrainValidationSplitModel(Model):
extra = dict()
return TrainValidationSplitModel(self.bestModel.copy(extra))
+ @since("2.0.0")
+ def write(self):
+ """Returns an MLWriter instance for this ML instance."""
+ return JavaMLWriter(self)
+
+ @since("2.0.0")
+ def save(self, path):
+ """Save this ML instance to the given path, a shortcut of `write().save(path)`."""
+ self.write().save(path)
+
+ @classmethod
+ @since("2.0.0")
+ def read(cls):
+ """Returns an MLReader instance for this class."""
+ return JavaMLReader(cls)
+
+ @classmethod
+ def _from_java(cls, java_stage):
+ """
+ Given a Java TrainValidationSplitModel, create and return a Python wrapper of it.
+ Used for ML persistence.
+ """
+
+ # Load information from java_stage to the instance.
+ bestModel = JavaParams._from_java(java_stage.bestModel())
+ estimator, epms, evaluator = \
+ super(TrainValidationSplitModel, cls)._from_java_impl(java_stage)
+ # Create a new instance of this stage.
+ py_stage = cls(bestModel=bestModel)\
+ .setEstimator(estimator).setEstimatorParamMaps(epms).setEvaluator(evaluator)
+ py_stage._resetUid(java_stage.uid())
+ return py_stage
+
+ def _to_java(self):
+ """
+ Transfer this instance to a Java TrainValidationSplitModel. Used for ML persistence.
+
+ :return: Java object equivalent to this instance.
+ """
+
+ sc = SparkContext._active_spark_context
+
+ _java_obj = JavaParams._new_java_obj(
+ "org.apache.spark.ml.tuning.TrainValidationSplitModel",
+ self.uid,
+ self.bestModel._to_java(),
+ _py2java(sc, []))
+ estimator, epms, evaluator = super(TrainValidationSplitModel, self)._to_java_impl()
+
+ _java_obj.set("evaluator", evaluator)
+ _java_obj.set("estimator", estimator)
+ _java_obj.set("estimatorParamMaps", epms)
+ return _java_obj
+
+
if __name__ == "__main__":
import doctest
+
from pyspark.context import SparkContext
from pyspark.sql import SQLContext
globs = globals().copy()
+
# The small batch size here ensures that we see multiple batches,
# even in these small test examples:
sc = SparkContext("local[2]", "ml.tuning tests")
sqlContext = SQLContext(sc)
globs['sc'] = sc
globs['sqlContext'] = sqlContext
- (failure_count, test_count) = doctest.testmod(
- globs=globs, optionflags=doctest.ELLIPSIS)
+ (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
sc.stop()
if failure_count:
exit(-1)
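The persistence hooks added to CrossValidator, CrossValidatorModel, TrainValidationSplit and TrainValidationSplitModel make a save/load round trip possible for tuning stages. A minimal sketch, assuming an active SparkContext; the estimator, grid and output path below are illustrative placeholders:

    from pyspark.ml.classification import LogisticRegression
    from pyspark.ml.evaluation import BinaryClassificationEvaluator
    from pyspark.ml.tuning import CrossValidator, ParamGridBuilder

    lr = LogisticRegression(maxIter=5)
    grid = ParamGridBuilder().addGrid(lr.regParam, [0.01, 0.1]).build()
    cv = CrossValidator(estimator=lr, estimatorParamMaps=grid,
                        evaluator=BinaryClassificationEvaluator(), numFolds=3)

    cv.save("/tmp/cv_demo")                            # shortcut for write().save(path)
    cv2 = CrossValidator.read().load("/tmp/cv_demo")   # read() returns a JavaMLReader
    # cv2 carries the same uid, estimator, param grid, evaluator and numFolds as cv.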
diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py
index 6703851262..9dfcef0e40 100644
--- a/python/pyspark/ml/util.py
+++ b/python/pyspark/ml/util.py
@@ -99,7 +99,7 @@ class MLWriter(object):
@inherit_doc
class JavaMLWriter(MLWriter):
"""
- (Private) Specialization of :py:class:`MLWriter` for :py:class:`JavaWrapper` types
+ (Private) Specialization of :py:class:`MLWriter` for :py:class:`JavaParams` types
"""
def __init__(self, instance):
@@ -134,13 +134,14 @@ class MLWritable(object):
.. versionadded:: 2.0.0
"""
+ @property
def write(self):
"""Returns an JavaMLWriter instance for this ML instance."""
raise NotImplementedError("MLWritable is not yet implemented for type: %r" % type(self))
def save(self, path):
"""Save this ML instance to the given path, a shortcut of `write().save(path)`."""
- self.write().save(path)
+ self.write.save(path)
@inherit_doc
@@ -149,6 +150,7 @@ class JavaMLWritable(MLWritable):
(Private) Mixin for ML instances that provide :py:class:`JavaMLWriter`.
"""
+ @property
def write(self):
"""Returns an JavaMLWriter instance for this ML instance."""
return JavaMLWriter(self)
@@ -176,7 +178,7 @@ class MLReader(object):
@inherit_doc
class JavaMLReader(MLReader):
"""
- (Private) Specialization of :py:class:`MLReader` for :py:class:`JavaWrapper` types
+ (Private) Specialization of :py:class:`MLReader` for :py:class:`JavaParams` types
"""
def __init__(self, clazz):
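For classes that mix in the updated JavaMLWritable, write is now an attribute rather than a method, so the call parentheses disappear; save(path) stays as the one-line shortcut. A minimal sketch, assuming model is any such instance and the path is illustrative:

    # Before this change: model.write().save("/tmp/model_demo")
    model.write.save("/tmp/model_demo")   # write is now a @property returning a JavaMLWriter
    # model.save("/tmp/model_demo")       # equivalent shortcut defined on MLWritable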
diff --git a/python/pyspark/ml/wrapper.py b/python/pyspark/ml/wrapper.py
index 35b0eba926..cd0e5b80d5 100644
--- a/python/pyspark/ml/wrapper.py
+++ b/python/pyspark/ml/wrapper.py
@@ -25,29 +25,32 @@ from pyspark.ml.util import _jvm
from pyspark.mllib.common import inherit_doc, _java2py, _py2java
-@inherit_doc
-class JavaWrapper(Params):
+class JavaWrapper(object):
"""
- Utility class to help create wrapper classes from Java/Scala
- implementations of pipeline components.
+ Wrapper class for a Java companion object
"""
+ def __init__(self, java_obj=None):
+ super(JavaWrapper, self).__init__()
+ self._java_obj = java_obj
- __metaclass__ = ABCMeta
-
- def __init__(self):
+ @classmethod
+ def _create_from_java_class(cls, java_class, *args):
"""
- Initialize the wrapped java object to None
+ Construct this object from the given Java class name and arguments
"""
- super(JavaWrapper, self).__init__()
- #: The wrapped Java companion object. Subclasses should initialize
- #: it properly. The param values in the Java object should be
- #: synced with the Python wrapper in fit/transform/evaluate/copy.
- self._java_obj = None
+ java_obj = JavaWrapper._new_java_obj(java_class, *args)
+ return cls(java_obj)
+
+ def _call_java(self, name, *args):
+ m = getattr(self._java_obj, name)
+ sc = SparkContext._active_spark_context
+ java_args = [_py2java(sc, arg) for arg in args]
+ return _java2py(sc, m(*java_args))
@staticmethod
def _new_java_obj(java_class, *args):
"""
- Construct a new Java object.
+ Returns a new Java object.
"""
sc = SparkContext._active_spark_context
java_obj = _jvm()
@@ -56,6 +59,18 @@ class JavaWrapper(Params):
java_args = [_py2java(sc, arg) for arg in args]
return java_obj(*java_args)
+
+@inherit_doc
+class JavaParams(JavaWrapper, Params):
+ """
+ Utility class to help create wrapper classes from Java/Scala
+ implementations of pipeline components.
+ """
+ #: The param values in the Java object should be
+ #: synced with the Python wrapper in fit/transform/evaluate/copy.
+
+ __metaclass__ = ABCMeta
+
def _make_java_param_pair(self, param, value):
"""
Makes a Java param pair.
@@ -76,6 +91,17 @@ class JavaWrapper(Params):
pair = self._make_java_param_pair(param, paramMap[param])
self._java_obj.set(pair)
+ def _transfer_param_map_to_java(self, pyParamMap):
+ """
+ Transforms a Python ParamMap into a Java ParamMap.
+ """
+ paramMap = JavaWrapper._new_java_obj("org.apache.spark.ml.param.ParamMap")
+ for param in self.params:
+ if param in pyParamMap:
+ pair = self._make_java_param_pair(param, pyParamMap[param])
+ paramMap.put([pair])
+ return paramMap
+
def _transfer_params_from_java(self):
"""
Transforms the embedded params from the companion Java object.
@@ -88,6 +114,18 @@ class JavaWrapper(Params):
value = _java2py(sc, self._java_obj.getOrDefault(java_param))
self._paramMap[param] = value
+ def _transfer_param_map_from_java(self, javaParamMap):
+ """
+ Transforms a Java ParamMap into a Python ParamMap.
+ """
+ sc = SparkContext._active_spark_context
+ paramMap = dict()
+ for pair in javaParamMap.toList():
+ param = pair.param()
+ if self.hasParam(str(param.name())):
+ paramMap[self.getParam(param.name())] = _java2py(sc, pair.value())
+ return paramMap
+
@staticmethod
def _empty_java_param_map():
"""
@@ -128,7 +166,7 @@ class JavaWrapper(Params):
stage_name = java_stage.getClass().getName().replace("org.apache.spark", "pyspark")
# Generate a default new instance from the stage_name class.
py_type = __get_class(stage_name)
- if issubclass(py_type, JavaWrapper):
+ if issubclass(py_type, JavaParams):
# Load information from java_stage to the instance.
py_stage = py_type()
py_stage._java_obj = java_stage
@@ -143,7 +181,7 @@ class JavaWrapper(Params):
@inherit_doc
-class JavaEstimator(Estimator, JavaWrapper):
+class JavaEstimator(JavaParams, Estimator):
"""
Base class for :py:class:`Estimator`s that wrap Java/Scala
implementations.
@@ -176,7 +214,7 @@ class JavaEstimator(Estimator, JavaWrapper):
@inherit_doc
-class JavaTransformer(Transformer, JavaWrapper):
+class JavaTransformer(JavaParams, Transformer):
"""
Base class for :py:class:`Transformer`s that wrap Java/Scala
implementations. Subclasses should ensure they have the transformer Java object
@@ -191,7 +229,7 @@ class JavaTransformer(Transformer, JavaWrapper):
@inherit_doc
-class JavaModel(Model, JavaTransformer):
+class JavaModel(JavaTransformer, Model):
"""
Base class for :py:class:`Model`s that wrap Java/Scala
implementations. Subclasses should inherit this class before
@@ -204,7 +242,7 @@ class JavaModel(Model, JavaTransformer):
"""
Initialize this instance with a Java model object.
Subclasses should call this constructor, initialize params,
- and then call _transformer_params_from_java.
+ and then call _transfer_params_from_java.
This instance can be instantiated without specifying java_model,
it will be assigned after that, but this scenario only used by
@@ -214,9 +252,8 @@ class JavaModel(Model, JavaTransformer):
these wrappers depend on pyspark.ml.util (both directly and via
other ML classes).
"""
- super(JavaModel, self).__init__()
+ super(JavaModel, self).__init__(java_model)
if java_model is not None:
- self._java_obj = java_model
self.uid = java_model.uid()
def copy(self, extra=None):
@@ -236,9 +273,3 @@ class JavaModel(Model, JavaTransformer):
that._java_obj = self._java_obj.copy(self._empty_java_param_map())
that._transfer_params_to_java()
return that
-
- def _call_java(self, name, *args):
- m = getattr(self._java_obj, name)
- sc = SparkContext._active_spark_context
- java_args = [_py2java(sc, arg) for arg in args]
- return _java2py(sc, m(*java_args))
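The split above leaves JavaWrapper as a thin holder for a Java companion object (creation plus method delegation) and moves all Param bookkeeping into JavaParams. The delegation idea behind _create_from_java_class and _call_java can be shown with a plain-Python analogue that replaces the py4j gateway with an ordinary object; everything below is illustrative and never touches the JVM:

    class CompanionWrapper(object):
        """Plain-Python analogue of JavaWrapper: hold a companion object, delegate calls."""
        def __init__(self, companion=None):
            self._obj = companion

        @classmethod
        def _create_from_class(cls, companion_class, *args):
            # Mirrors _create_from_java_class: build the companion, then wrap it.
            return cls(companion_class(*args))

        def _call(self, name, *args):
            # Mirrors _call_java: look the method up on the companion and invoke it.
            return getattr(self._obj, name)(*args)

    class Counter(object):
        def __init__(self, start):
            self.n = start
        def add(self, k):
            self.n += k
            return self.n

    w = CompanionWrapper._create_from_class(Counter, 10)
    print(w._call("add", 5))   # 15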
diff --git a/python/pyspark/mllib/feature.py b/python/pyspark/mllib/feature.py
index 6129353525..b3dd2f63a5 100644
--- a/python/pyspark/mllib/feature.py
+++ b/python/pyspark/mllib/feature.py
@@ -379,6 +379,17 @@ class HashingTF(object):
"""
def __init__(self, numFeatures=1 << 20):
self.numFeatures = numFeatures
+ self.binary = False
+
+ @since("2.0.0")
+ def setBinary(self, value):
+ """
+ If True, term frequency vector will be binary such that non-zero
+ term counts will be set to 1
+ (default: False)
+ """
+ self.binary = value
+ return self
@since('1.2.0')
def indexOf(self, term):
@@ -398,7 +409,7 @@ class HashingTF(object):
freq = {}
for term in document:
i = self.indexOf(term)
- freq[i] = freq.get(i, 0) + 1.0
+ freq[i] = 1.0 if self.binary else freq.get(i, 0) + 1.0
return Vectors.sparse(self.numFeatures, freq.items())
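The new setBinary flag makes transform emit 1.0 for every present term instead of its raw count. A minimal sketch; transforming a plain Python list needs no SparkContext:

    from pyspark.mllib.feature import HashingTF

    doc = "a a b c c c".split(" ")
    tf = HashingTF(numFeatures=100)
    print(tf.transform(doc))                  # raw term frequencies, e.g. 2.0 for "a"
    print(tf.setBinary(True).transform(doc))  # every hashed term maps to 1.0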
diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py
index 5f515b666c..ac55fbf798 100644
--- a/python/pyspark/mllib/tests.py
+++ b/python/pyspark/mllib/tests.py
@@ -58,6 +58,7 @@ from pyspark.mllib.recommendation import Rating
from pyspark.mllib.regression import LabeledPoint, StreamingLinearRegressionWithSGD
from pyspark.mllib.random import RandomRDDs
from pyspark.mllib.stat import Statistics
+from pyspark.mllib.feature import HashingTF
from pyspark.mllib.feature import Word2Vec
from pyspark.mllib.feature import IDF
from pyspark.mllib.feature import StandardScaler, ElementwiseProduct
@@ -1583,6 +1584,21 @@ class ALSTests(MLlibTestCase):
self.assertRaises(Py4JJavaError, self.sc._jvm.SerDe.loads, bytearray(ser.dumps(r)))
+class HashingTFTest(MLlibTestCase):
+
+ def test_binary_term_freqs(self):
+ hashingTF = HashingTF(100).setBinary(True)
+ doc = "a a b c c c".split(" ")
+ n = hashingTF.numFeatures
+ output = hashingTF.transform(doc).toArray()
+ expected = Vectors.sparse(n, {hashingTF.indexOf("a"): 1.0,
+ hashingTF.indexOf("b"): 1.0,
+ hashingTF.indexOf("c"): 1.0}).toArray()
+ for i in range(0, n):
+ self.assertAlmostEqual(output[i], expected[i], 14, "Error at " + str(i) +
+ ": expected " + str(expected[i]) + ", got " + str(output[i]))
+
+
if __name__ == "__main__":
from pyspark.mllib.tests import *
if not _have_scipy:
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 37574cea0b..8978f028c5 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -115,7 +115,7 @@ def _parse_memory(s):
2048
"""
units = {'g': 1024, 'm': 1, 't': 1 << 20, 'k': 1.0 / 1024}
- if s[-1] not in units:
+ if s[-1].lower() not in units:
raise ValueError("invalid format: " + s)
return int(float(s[:-1]) * units[s[-1].lower()])
@@ -2299,14 +2299,14 @@ class RDD(object):
"""
Return an iterator that contains all of the elements in this RDD.
The iterator will consume as much memory as the largest partition in this RDD.
+
>>> rdd = sc.parallelize(range(10))
>>> [x for x in rdd.toLocalIterator()]
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
"""
- for partition in range(self.getNumPartitions()):
- rows = self.context.runJob(self, lambda x: x, [partition])
- for row in rows:
- yield row
+ with SCCallSiteSync(self.context) as css:
+ port = self.ctx._jvm.PythonRDD.toLocalIteratorAndServe(self._jrdd.rdd())
+ return _load_from_socket(port, self._jrdd_deserializer)
def _prepare_for_python_RDD(sc, command):
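After this change toLocalIterator streams partitions through a single server-side iterator (PythonRDD.toLocalIteratorAndServe) instead of launching one job per partition; the Python-side usage is unchanged. A minimal sketch, assuming an active SparkContext sc:

    rdd = sc.parallelize(range(10), 3)
    # Driver memory is bounded by the largest partition, not the whole RDD.
    for x in rdd.toLocalIterator():
        print(x)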
diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py
index 4008332c84..11dfcfe13e 100644
--- a/python/pyspark/sql/context.py
+++ b/python/pyspark/sql/context.py
@@ -405,7 +405,7 @@ class SQLContext(object):
>>> sqlContext.createDataFrame(rdd, "boolean").collect() # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
...
- Py4JJavaError:...
+ Py4JJavaError: ...
"""
if isinstance(data, DataFrame):
raise TypeError("data is already a DataFrame")
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 7a69c4c70c..b4fa836893 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -60,7 +60,7 @@ class DataFrame(object):
people = sqlContext.read.parquet("...")
department = sqlContext.read.parquet("...")
- people.filter(people.age > 30).join(department, people.deptId == department.id)) \
+ people.filter(people.age > 30).join(department, people.deptId == department.id)\
.groupBy(department.name, "gender").agg({"salary": "avg", "age": "max"})
.. note:: Experimental
@@ -241,6 +241,20 @@ class DataFrame(object):
return list(_load_from_socket(port, BatchedSerializer(PickleSerializer())))
@ignore_unicode_prefix
+ @since(2.0)
+ def toLocalIterator(self):
+ """
+ Returns an iterator that contains all of the rows in this :class:`DataFrame`.
+ The iterator will consume as much memory as the largest partition in this DataFrame.
+
+ >>> list(df.toLocalIterator())
+ [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')]
+ """
+ with SCCallSiteSync(self._sc) as css:
+ port = self._jdf.toPythonIterator()
+ return _load_from_socket(port, BatchedSerializer(PickleSerializer()))
+
+ @ignore_unicode_prefix
@since(1.3)
def limit(self, num):
"""Limits the result count to the number specified.
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index f5d959ef98..5017ab5b36 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -25,7 +25,7 @@ if sys.version < "3":
from itertools import imap as map
from pyspark import since, SparkContext
-from pyspark.rdd import _wrap_function, ignore_unicode_prefix
+from pyspark.rdd import _prepare_for_python_RDD, ignore_unicode_prefix
from pyspark.serializers import PickleSerializer, AutoBatchedSerializer
from pyspark.sql.types import StringType
from pyspark.sql.column import Column, _to_java_column, _to_seq
@@ -1053,6 +1053,55 @@ def to_utc_timestamp(timestamp, tz):
return Column(sc._jvm.functions.to_utc_timestamp(_to_java_column(timestamp), tz))
+@since(2.0)
+@ignore_unicode_prefix
+def window(timeColumn, windowDuration, slideDuration=None, startTime=None):
+ """Bucketize rows into one or more time windows given a timestamp specifying column. Window
+ starts are inclusive but the window ends are exclusive, e.g. 12:05 will be in the window
+ [12:05,12:10) but not in [12:00,12:05). Windows can support microsecond precision. Windows in
+ the order of months are not supported.
+
+ The time column must be of TimestampType.
+
+ Durations are provided as strings, e.g. '1 second', '1 day 12 hours', '2 minutes'. Valid
+ interval strings are 'week', 'day', 'hour', 'minute', 'second', 'millisecond', 'microsecond'.
+ If the `slideDuration` is not provided, the windows will be tumbling windows.
+
+ The startTime is the offset with respect to 1970-01-01 00:00:00 UTC with which to start
+ window intervals. For example, in order to have hourly tumbling windows that start 15 minutes
+ past the hour, e.g. 12:15-13:15, 13:15-14:15... provide `startTime` as `15 minutes`.
+
+ The output column will be a struct called 'window' by default with the nested columns 'start'
+ and 'end', where 'start' and 'end' will be of `TimestampType`.
+
+ >>> df = sqlContext.createDataFrame([("2016-03-11 09:00:07", 1)]).toDF("date", "val")
+ >>> w = df.groupBy(window("date", "5 seconds")).agg(sum("val").alias("sum"))
+ >>> w.select(w.window.start.cast("string").alias("start"),
+ ... w.window.end.cast("string").alias("end"), "sum").collect()
+ [Row(start=u'2016-03-11 09:00:05', end=u'2016-03-11 09:00:10', sum=1)]
+ """
+ def check_string_field(field, fieldName):
+ if not field or type(field) is not str:
+ raise TypeError("%s should be provided as a string" % fieldName)
+
+ sc = SparkContext._active_spark_context
+ time_col = _to_java_column(timeColumn)
+ check_string_field(windowDuration, "windowDuration")
+ if slideDuration and startTime:
+ check_string_field(slideDuration, "slideDuration")
+ check_string_field(startTime, "startTime")
+ res = sc._jvm.functions.window(time_col, windowDuration, slideDuration, startTime)
+ elif slideDuration:
+ check_string_field(slideDuration, "slideDuration")
+ res = sc._jvm.functions.window(time_col, windowDuration, slideDuration)
+ elif startTime:
+ check_string_field(startTime, "startTime")
+ res = sc._jvm.functions.window(time_col, windowDuration, windowDuration, startTime)
+ else:
+ res = sc._jvm.functions.window(time_col, windowDuration)
+ return Column(res)
+
+
# ---------------------------- misc functions ----------------------------------
@since(1.5)
@@ -1648,6 +1697,13 @@ def sort_array(col, asc=True):
# ---------------------------- User Defined Function ----------------------------------
+def _wrap_function(sc, func, returnType):
+ command = (func, returnType)
+ pickled_command, broadcast_vars, env, includes = _prepare_for_python_RDD(sc, command)
+ return sc._jvm.PythonFunction(bytearray(pickled_command), env, includes, sc.pythonExec,
+ sc.pythonVer, broadcast_vars, sc._javaAccumulator)
+
+
class UserDefinedFunction(object):
"""
User defined function in Python
@@ -1662,14 +1718,12 @@ class UserDefinedFunction(object):
def _create_judf(self, name):
from pyspark.sql import SQLContext
- f, returnType = self.func, self.returnType # put them in closure `func`
- func = lambda _, it: map(lambda x: returnType.toInternal(f(*x)), it)
- ser = AutoBatchedSerializer(PickleSerializer())
sc = SparkContext.getOrCreate()
- wrapped_func = _wrap_function(sc, func, ser, ser)
+ wrapped_func = _wrap_function(sc, self.func, self.returnType)
ctx = SQLContext.getOrCreate(sc)
jdt = ctx._ssql_ctx.parseDataType(self.returnType.json())
if name is None:
+ f = self.func
name = f.__name__ if hasattr(f, '__name__') else f.__class__.__name__
judf = sc._jvm.org.apache.spark.sql.execution.python.UserDefinedPythonFunction(
name, wrapped_func, jdt)
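A minimal sketch of the new window grouping expression, adapted from the doctest above and assuming an active sqlContext:

    from pyspark.sql.functions import window, sum as sql_sum

    df = sqlContext.createDataFrame([("2016-03-11 09:00:07", 1)]).toDF("date", "val")
    # 5-second tumbling windows; pass slideDuration as well for sliding windows.
    agg = df.groupBy(window("date", "5 seconds")).agg(sql_sum("val").alias("sum"))
    agg.select(agg.window.start.cast("string").alias("start"),
               agg.window.end.cast("string").alias("end"), "sum").show()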
diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py
index cca57a385c..0cef37e57c 100644
--- a/python/pyspark/sql/readwriter.py
+++ b/python/pyspark/sql/readwriter.py
@@ -152,8 +152,8 @@ class DataFrameReader(object):
You can set the following JSON-specific options to deal with non-standard JSON files:
* ``primitivesAsString`` (default ``false``): infers all primitive values as a string \
type
- * `floatAsBigDecimal` (default `false`): infers all floating-point values as a decimal \
- type
+ * `prefersDecimal` (default `false`): infers all floating-point values as a decimal \
+ type. If the values do not fit in decimal, then it infers them as doubles.
* ``allowComments`` (default ``false``): ignores Java/C++ style comment in JSON records
* ``allowUnquotedFieldNames`` (default ``false``): allows unquoted JSON field names
* ``allowSingleQuotes`` (default ``true``): allows single quotes in addition to double \
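The renamed prefersDecimal option can be passed through the reader's generic option call; a hedged sketch, where the file path is a placeholder and option() is assumed to be the way this flag is supplied:

    # Infer floating-point JSON values as decimals, falling back to doubles
    # when they do not fit (prefersDecimal replaces floatAsBigDecimal).
    df = sqlContext.read.option("prefersDecimal", "true").json("/tmp/records.json")
    df.printSchema()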
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 83ef76c13c..e4f79c911c 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -51,7 +51,7 @@ from pyspark.sql.types import UserDefinedType, _infer_type
from pyspark.tests import ReusedPySparkTestCase
from pyspark.sql.functions import UserDefinedFunction, sha2
from pyspark.sql.window import Window
-from pyspark.sql.utils import AnalysisException, IllegalArgumentException
+from pyspark.sql.utils import AnalysisException, ParseException, IllegalArgumentException
class UTCOffsetTimezone(datetime.tzinfo):
@@ -305,6 +305,25 @@ class SQLTests(ReusedPySparkTestCase):
[res] = self.sqlCtx.sql("SELECT strlen(a) FROM test WHERE strlen(a) > 1").collect()
self.assertEqual(4, res[0])
+ def test_chained_udf(self):
+ self.sqlCtx.registerFunction("double", lambda x: x + x, IntegerType())
+ [row] = self.sqlCtx.sql("SELECT double(1)").collect()
+ self.assertEqual(row[0], 2)
+ [row] = self.sqlCtx.sql("SELECT double(double(1))").collect()
+ self.assertEqual(row[0], 4)
+ [row] = self.sqlCtx.sql("SELECT double(double(1) + 1)").collect()
+ self.assertEqual(row[0], 6)
+
+ def test_multiple_udfs(self):
+ self.sqlCtx.registerFunction("double", lambda x: x * 2, IntegerType())
+ [row] = self.sqlCtx.sql("SELECT double(1), double(2)").collect()
+ self.assertEqual(tuple(row), (2, 4))
+ [row] = self.sqlCtx.sql("SELECT double(double(1)), double(double(2) + 2)").collect()
+ self.assertEqual(tuple(row), (4, 12))
+ self.sqlCtx.registerFunction("add", lambda x, y: x + y, IntegerType())
+ [row] = self.sqlCtx.sql("SELECT double(add(1, 2)), add(double(2), 1)").collect()
+ self.assertEqual(tuple(row), (6, 5))
+
def test_udf_with_array_type(self):
d = [Row(l=list(range(3)), d={"key": list(range(5))})]
rdd = self.sc.parallelize(d)
@@ -324,6 +343,15 @@ class SQLTests(ReusedPySparkTestCase):
[res] = self.sqlCtx.sql("SELECT MYUDF('')").collect()
self.assertEqual("", res[0])
+ def test_udf_with_aggregate_function(self):
+ df = self.sqlCtx.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"])
+ from pyspark.sql.functions import udf, col
+ from pyspark.sql.types import BooleanType
+
+ my_filter = udf(lambda a: a == 1, BooleanType())
+ sel = df.select(col("key")).distinct().filter(my_filter(col("key")))
+ self.assertEqual(sel.collect(), [Row(key=1)])
+
def test_basic_functions(self):
rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}'])
df = self.sqlCtx.read.json(rdd)
@@ -1130,7 +1158,9 @@ class SQLTests(ReusedPySparkTestCase):
def test_capture_analysis_exception(self):
self.assertRaises(AnalysisException, lambda: self.sqlCtx.sql("select abc"))
self.assertRaises(AnalysisException, lambda: self.df.selectExpr("a + b"))
- self.assertRaises(AnalysisException, lambda: self.sqlCtx.sql("abc"))
+
+ def test_capture_parse_exception(self):
+ self.assertRaises(ParseException, lambda: self.sqlCtx.sql("abc"))
def test_capture_illegalargument_exception(self):
self.assertRaisesRegexp(IllegalArgumentException, "Setting negative mapred.reduce.tasks",
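The new tests cover Python UDFs that are nested and used side by side in a single query, which the reworked UDF plumbing in functions.py and worker.py now evaluates in one pass. A minimal sketch, assuming an active SQLContext sqlCtx:

    from pyspark.sql.types import IntegerType

    sqlCtx.registerFunction("double", lambda x: x * 2, IntegerType())
    sqlCtx.registerFunction("add", lambda x, y: x + y, IntegerType())
    # Nested and side-by-side Python UDFs in one SQL statement -> values 6 and 5.
    print(sqlCtx.sql("SELECT double(add(1, 2)), add(double(2), 1)").collect())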
diff --git a/python/pyspark/sql/utils.py b/python/pyspark/sql/utils.py
index b0a0373372..7ea0e0d5c9 100644
--- a/python/pyspark/sql/utils.py
+++ b/python/pyspark/sql/utils.py
@@ -33,6 +33,12 @@ class AnalysisException(CapturedException):
"""
+class ParseException(CapturedException):
+ """
+ Failed to parse a SQL command.
+ """
+
+
class IllegalArgumentException(CapturedException):
"""
Passed an illegal or inappropriate argument.
@@ -49,6 +55,8 @@ def capture_sql_exception(f):
e.java_exception.getStackTrace()))
if s.startswith('org.apache.spark.sql.AnalysisException: '):
raise AnalysisException(s.split(': ', 1)[1], stackTrace)
+ if s.startswith('org.apache.spark.sql.catalyst.parser.ParseException: '):
+ raise ParseException(s.split(': ', 1)[1], stackTrace)
if s.startswith('java.lang.IllegalArgumentException: '):
raise IllegalArgumentException(s.split(': ', 1)[1], stackTrace)
raise
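With the new exception type, syntactically invalid SQL now surfaces as ParseException instead of a generic AnalysisException; a minimal sketch, assuming an active sqlContext:

    from pyspark.sql.utils import ParseException

    try:
        sqlContext.sql("abc")   # not valid SQL
    except ParseException as e:
        print("could not parse query: %s" % e)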
diff --git a/python/pyspark/storagelevel.py b/python/pyspark/storagelevel.py
index d4f184a85d..ef012d27cb 100644
--- a/python/pyspark/storagelevel.py
+++ b/python/pyspark/storagelevel.py
@@ -44,7 +44,7 @@ class StorageLevel(object):
result = ""
result += "Disk " if self.useDisk else ""
result += "Memory " if self.useMemory else ""
- result += "Tachyon " if self.useOffHeap else ""
+ result += "OffHeap " if self.useOffHeap else ""
result += "Deserialized " if self.deserialized else "Serialized "
result += "%sx Replicated" % self.replication
return result
@@ -55,7 +55,7 @@ StorageLevel.MEMORY_ONLY = StorageLevel(False, True, False, False)
StorageLevel.MEMORY_ONLY_2 = StorageLevel(False, True, False, False, 2)
StorageLevel.MEMORY_AND_DISK = StorageLevel(True, True, False, False)
StorageLevel.MEMORY_AND_DISK_2 = StorageLevel(True, True, False, False, 2)
-StorageLevel.OFF_HEAP = StorageLevel(False, False, True, False, 1)
+StorageLevel.OFF_HEAP = StorageLevel(True, True, True, False, 1)
"""
.. note:: The following four storage level constants are deprecated in 2.0, since the records \
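After this change OFF_HEAP also permits on-heap memory and disk, and the string form no longer mentions Tachyon. A small sketch, assuming an active SparkContext sc; actually materializing off-heap blocks typically also requires off-heap memory to be enabled in the Spark configuration:

    from pyspark import StorageLevel

    rdd = sc.parallelize(range(100)).persist(StorageLevel.OFF_HEAP)
    print(rdd.getStorageLevel())   # reports Disk/Memory/OffHeap rather than "Tachyon"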
diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py
index d010c0e008..148bf7e8ff 100644
--- a/python/pyspark/streaming/tests.py
+++ b/python/pyspark/streaming/tests.py
@@ -1482,7 +1482,7 @@ def search_kafka_assembly_jar():
raise Exception(
("Failed to find Spark Streaming kafka assembly jar in %s. " % kafka_assembly_dir) +
"You need to build Spark with "
- "'build/sbt assembly/assembly streaming-kafka-assembly/assembly' or "
+ "'build/sbt assembly/package streaming-kafka-assembly/assembly' or "
"'build/mvn package' before running this test.")
elif len(jars) > 1:
raise Exception(("Found multiple Spark Streaming Kafka assembly JARs: %s; please "
@@ -1548,7 +1548,7 @@ if __name__ == "__main__":
elif are_kinesis_tests_enabled is False:
sys.stderr.write("Skipping all Kinesis Python tests as the optional Kinesis project was "
"not compiled into a JAR. To run these tests, "
- "you need to build Spark with 'build/sbt -Pkinesis-asl assembly/assembly "
+ "you need to build Spark with 'build/sbt -Pkinesis-asl assembly/package "
"streaming-kinesis-asl-assembly/assembly' or "
"'build/mvn -Pkinesis-asl package' before running this test.")
else:
@@ -1556,7 +1556,7 @@ if __name__ == "__main__":
("Failed to find Spark Streaming Kinesis assembly jar in %s. "
% kinesis_asl_assembly_dir) +
"You need to build Spark with 'build/sbt -Pkinesis-asl "
- "assembly/assembly streaming-kinesis-asl-assembly/assembly'"
+ "assembly/package streaming-kinesis-asl-assembly/assembly'"
"or 'build/mvn -Pkinesis-asl package' before running this test.")
sys.stderr.write("Running tests: %s \n" % (str(testcases)))
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index a5a83c7e38..97ea39dde0 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -694,6 +694,21 @@ class RDDTests(ReusedPySparkTestCase):
m = self.sc.parallelize(range(1), 1).map(lambda x: len(bdata.value)).sum()
self.assertEqual(N, m)
+ def test_unpersist(self):
+ N = 1000
+ data = [[float(i) for i in range(300)] for i in range(N)]
+ bdata = self.sc.broadcast(data) # 3MB
+ bdata.unpersist()
+ m = self.sc.parallelize(range(1), 1).map(lambda x: len(bdata.value)).sum()
+ self.assertEqual(N, m)
+ bdata.destroy()
+ try:
+ self.sc.parallelize(range(1), 1).map(lambda x: len(bdata.value)).sum()
+ except Exception as e:
+ pass
+ else:
+ raise Exception("job should fail after destroying the broadcast")
+
def test_multiple_broadcasts(self):
N = 1 << 21
b1 = self.sc.broadcast(set(range(N))) # multiple blocks in JVM
@@ -1899,6 +1914,13 @@ class ContextTests(unittest.TestCase):
with SparkContext.getOrCreate() as sc:
self.assertTrue(SparkContext.getOrCreate() is sc)
+ def test_parallelize_eager_cleanup(self):
+ with SparkContext() as sc:
+ temp_files = os.listdir(sc._temp_dir)
+ rdd = sc.parallelize([0, 1, 2])
+ post_parallelize_temp_files = os.listdir(sc._temp_dir)
+ self.assertEqual(temp_files, post_parallelize_temp_files)
+
def test_stop(self):
sc = SparkContext()
self.assertNotEqual(SparkContext._active_spark_context, None)
@@ -1966,6 +1988,18 @@ class ContextTests(unittest.TestCase):
self.assertGreater(sc.startTime, 0)
+class ConfTests(unittest.TestCase):
+ def test_memory_conf(self):
+ memoryList = ["1T", "1G", "1M", "1024K"]
+ for memory in memoryList:
+ sc = SparkContext(conf=SparkConf().set("spark.python.worker.memory", memory))
+ l = list(range(1024))
+ random.shuffle(l)
+ rdd = sc.parallelize(l, 4)
+ self.assertEqual(sorted(l), rdd.sortBy(lambda x: x).collect())
+ sc.stop()
+
+
@unittest.skipIf(not _have_scipy, "SciPy not installed")
class SciPyTests(PySparkTestCase):
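The new test distinguishes unpersist (executors drop their cached copies, but the broadcast can be re-sent) from destroy (the broadcast becomes unusable). A minimal sketch, assuming an active SparkContext sc:

    b = sc.broadcast(list(range(1000)))
    print(sc.parallelize(range(1), 1).map(lambda _: len(b.value)).sum())   # 1000

    b.unpersist()   # executors re-fetch the value on next use
    print(sc.parallelize(range(1), 1).map(lambda _: len(b.value)).sum())   # still 1000

    b.destroy()     # releases the data on the driver too; further use raises an error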
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 42c2f8b759..cf47ab8f96 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -29,7 +29,7 @@ from pyspark.accumulators import _accumulatorRegistry
from pyspark.broadcast import Broadcast, _broadcastRegistry
from pyspark.files import SparkFiles
from pyspark.serializers import write_with_length, write_int, read_long, \
- write_long, read_int, SpecialLengths, UTF8Deserializer, PickleSerializer
+ write_long, read_int, SpecialLengths, UTF8Deserializer, PickleSerializer, BatchedSerializer
from pyspark import shuffle
pickleSer = PickleSerializer()
@@ -50,6 +50,65 @@ def add_path(path):
sys.path.insert(1, path)
+def read_command(serializer, file):
+ command = serializer._read_with_length(file)
+ if isinstance(command, Broadcast):
+ command = serializer.loads(command.value)
+ return command
+
+
+def chain(f, g):
+ """Chain two functions together."""
+ return lambda *a: g(f(*a))
+
+
+def wrap_udf(f, return_type):
+ if return_type.needConversion():
+ toInternal = return_type.toInternal
+ return lambda *a: toInternal(f(*a))
+ else:
+ return lambda *a: f(*a)
+
+
+def read_single_udf(pickleSer, infile):
+ num_arg = read_int(infile)
+ arg_offsets = [read_int(infile) for i in range(num_arg)]
+ row_func = None
+ for i in range(read_int(infile)):
+ f, return_type = read_command(pickleSer, infile)
+ if row_func is None:
+ row_func = f
+ else:
+ row_func = chain(row_func, f)
+ # the last returnType will be the return type of UDF
+ return arg_offsets, wrap_udf(row_func, return_type)
+
+
+def read_udfs(pickleSer, infile):
+ num_udfs = read_int(infile)
+ if num_udfs == 1:
+ # fast path for single UDF
+ _, udf = read_single_udf(pickleSer, infile)
+ mapper = lambda a: udf(*a)
+ else:
+ udfs = {}
+ call_udf = []
+ for i in range(num_udfs):
+ arg_offsets, udf = read_single_udf(pickleSer, infile)
+ udfs['f%d' % i] = udf
+ args = ["a[%d]" % o for o in arg_offsets]
+ call_udf.append("f%d(%s)" % (i, ", ".join(args)))
+ # Create function like this:
+ # lambda a: (f0(a0), f1(a1, a2), f2(a3))
+ mapper_str = "lambda a: (%s)" % (", ".join(call_udf))
+ mapper = eval(mapper_str, udfs)
+
+ func = lambda _, it: map(mapper, it)
+ ser = BatchedSerializer(PickleSerializer(), 100)
+ # profiling is not supported for UDF
+ return func, None, ser, ser
+
+
def main(infile, outfile):
try:
boot_time = time.time()
@@ -95,10 +154,12 @@ def main(infile, outfile):
_broadcastRegistry.pop(bid)
_accumulatorRegistry.clear()
- command = pickleSer._read_with_length(infile)
- if isinstance(command, Broadcast):
- command = pickleSer.loads(command.value)
- func, profiler, deserializer, serializer = command
+ is_sql_udf = read_int(infile)
+ if is_sql_udf:
+ func, profiler, deserializer, serializer = read_udfs(pickleSer, infile)
+ else:
+ func, profiler, deserializer, serializer = read_command(pickleSer, infile)
+
init_time = time.time()
def process():
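For multiple UDFs the worker builds a single row mapper by generating a small lambda over the deserialized functions (the mapper_str/eval step above). The composition trick can be shown in isolation with plain Python; nothing below involves Spark and all names are illustrative:

    # Pretend these callables and argument offsets came from read_single_udf.
    udfs = {"f0": lambda x: x * 2, "f1": lambda x, y: x + y}
    arg_offsets = {"f0": [0], "f1": [1, 2]}

    call_udf = []
    for name, offsets in sorted(arg_offsets.items()):
        args = ", ".join("a[%d]" % o for o in offsets)
        call_udf.append("%s(%s)" % (name, args))
    # Equivalent to: lambda a: (f0(a[0]), f1(a[1], a[2]))
    mapper = eval("lambda a: (%s)" % ", ".join(call_udf), udfs)

    print(mapper((3, 4, 5)))   # (6, 9)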
diff --git a/python/run-tests.py b/python/run-tests.py
index a9f8854e6f..38b3bb84c1 100755
--- a/python/run-tests.py
+++ b/python/run-tests.py
@@ -53,11 +53,25 @@ LOG_FILE = os.path.join(SPARK_HOME, "python/unit-tests.log")
FAILURE_REPORTING_LOCK = Lock()
LOGGER = logging.getLogger()
+# Find out where the assembly jars are located.
+for scala in ["2.11", "2.10"]:
+ build_dir = os.path.join(SPARK_HOME, "assembly", "target", "scala-" + scala)
+ if os.path.isdir(build_dir):
+ SPARK_DIST_CLASSPATH = os.path.join(build_dir, "jars", "*")
+ break
+else:
+ raise Exception("Cannot find assembly build directory, please build Spark first.")
+
def run_individual_python_test(test_name, pyspark_python):
env = dict(os.environ)
- env.update({'SPARK_TESTING': '1', 'PYSPARK_PYTHON': which(pyspark_python),
- 'PYSPARK_DRIVER_PYTHON': which(pyspark_python)})
+ env.update({
+ 'SPARK_DIST_CLASSPATH': SPARK_DIST_CLASSPATH,
+ 'SPARK_TESTING': '1',
+ 'SPARK_PREPEND_CLASSES': '1',
+ 'PYSPARK_PYTHON': which(pyspark_python),
+ 'PYSPARK_DRIVER_PYTHON': which(pyspark_python)
+ })
LOGGER.debug("Starting test(%s): %s", pyspark_python, test_name)
start_time = time.time()
try:
diff --git a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoop.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoop.scala
index 67a616dc15..c5dc6ba221 100644
--- a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoop.scala
+++ b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoop.scala
@@ -797,9 +797,11 @@ class SparkILoop(
// echo("Switched " + (if (old) "off" else "on") + " result printing.")
}
- /** Run one command submitted by the user. Two values are returned:
- * (1) whether to keep running, (2) the line to record for replay,
- * if any. */
+ /**
+ * Run one command submitted by the user. Two values are returned:
+ * (1) whether to keep running, (2) the line to record for replay,
+ * if any.
+ */
private[repl] def command(line: String): Result = {
if (line startsWith ":") {
val cmd = line.tail takeWhile (x => !x.isWhitespace)
@@ -841,12 +843,13 @@ class SparkILoop(
}
import paste.{ ContinueString, PromptString }
- /** Interpret expressions starting with the first line.
- * Read lines until a complete compilation unit is available
- * or until a syntax error has been seen. If a full unit is
- * read, go ahead and interpret it. Return the full string
- * to be recorded for replay, if any.
- */
+ /**
+ * Interpret expressions starting with the first line.
+ * Read lines until a complete compilation unit is available
+ * or until a syntax error has been seen. If a full unit is
+ * read, go ahead and interpret it. Return the full string
+ * to be recorded for replay, if any.
+ */
private def interpretStartingWith(code: String): Option[String] = {
// signal completion non-completion input has been received
in.completion.resetVerbosity()
diff --git a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkImports.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkImports.scala
index 1d0fe10d3d..f22776592c 100644
--- a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkImports.scala
+++ b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkImports.scala
@@ -118,8 +118,9 @@ private[repl] trait SparkImports {
case class ReqAndHandler(req: Request, handler: MemberHandler) { }
def reqsToUse: List[ReqAndHandler] = {
- /** Loop through a list of MemberHandlers and select which ones to keep.
- * 'wanted' is the set of names that need to be imported.
+ /**
+ * Loop through a list of MemberHandlers and select which ones to keep.
+ * 'wanted' is the set of names that need to be imported.
*/
def select(reqs: List[ReqAndHandler], wanted: Set[Name]): List[ReqAndHandler] = {
// Single symbol imports might be implicits! See bug #1752. Rather than
diff --git a/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala b/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala
index 6b9aa5071e..547da8f713 100644
--- a/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala
+++ b/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala
@@ -285,7 +285,7 @@ class ReplSuite extends SparkFunSuite {
val output = runInterpreter("local",
"""
|import org.apache.spark.sql.functions._
- |import org.apache.spark.sql.Encoder
+ |import org.apache.spark.sql.{Encoder, Encoders}
|import org.apache.spark.sql.expressions.Aggregator
|import org.apache.spark.sql.TypedColumn
|val simpleSum = new Aggregator[Int, Int, Int] {
@@ -293,6 +293,8 @@ class ReplSuite extends SparkFunSuite {
| def reduce(b: Int, a: Int) = b + a // Add an element to the running total
| def merge(b1: Int, b2: Int) = b1 + b2 // Merge intermediate values.
| def finish(b: Int) = b // Return the final result.
+ | def bufferEncoder: Encoder[Int] = Encoders.scalaInt
+ | def outputEncoder: Encoder[Int] = Encoders.scalaInt
|}.toColumn
|
|val ds = Seq(1, 2, 3, 4).toDS()
@@ -339,30 +341,6 @@ class ReplSuite extends SparkFunSuite {
}
}
- test("Datasets agg type-inference") {
- val output = runInterpreter("local",
- """
- |import org.apache.spark.sql.functions._
- |import org.apache.spark.sql.Encoder
- |import org.apache.spark.sql.expressions.Aggregator
- |import org.apache.spark.sql.TypedColumn
- |/** An `Aggregator` that adds up any numeric type returned by the given function. */
- |class SumOf[I, N : Numeric](f: I => N) extends Aggregator[I, N, N] {
- | val numeric = implicitly[Numeric[N]]
- | override def zero: N = numeric.zero
- | override def reduce(b: N, a: I): N = numeric.plus(b, f(a))
- | override def merge(b1: N,b2: N): N = numeric.plus(b1, b2)
- | override def finish(reduction: N): N = reduction
- |}
- |
- |def sum[I, N : Numeric : Encoder](f: I => N): TypedColumn[I, N] = new SumOf(f).toColumn
- |val ds = Seq((1, 1, 2L), (1, 2, 3L), (1, 3, 4L), (2, 1, 5L)).toDS()
- |ds.groupBy(_._1).agg(sum(_._2), sum(_._3)).collect()
- """.stripMargin)
- assertDoesNotContain("error:", output)
- assertDoesNotContain("Exception", output)
- }
-
test("collecting objects of class defined in repl") {
val output = runInterpreter("local[2]",
"""
diff --git a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoop.scala b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoop.scala
index 7ed6d3b1f9..db09d6ace1 100644
--- a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoop.scala
+++ b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoop.scala
@@ -19,12 +19,11 @@ package org.apache.spark.repl
import java.io.BufferedReader
-import Predef.{println => _, _}
-import scala.util.Properties.{javaVersion, versionString, javaVmName}
-
-import scala.tools.nsc.interpreter.{JPrintWriter, ILoop}
+import scala.Predef.{println => _, _}
import scala.tools.nsc.Settings
+import scala.tools.nsc.interpreter.{ILoop, JPrintWriter}
import scala.tools.nsc.util.stringFromStream
+import scala.util.Properties.{javaVersion, javaVmName, versionString}
/**
* A Spark-specific interactive shell.
@@ -75,11 +74,9 @@ class SparkILoop(in0: Option[BufferedReader], out: JPrintWriter)
echo("Type :help for more information.")
}
- import LoopCommand.{ cmd, nullary }
-
- private val blockedCommands = Set("implicits", "javap", "power", "type", "kind")
+ private val blockedCommands = Set("implicits", "javap", "power", "type", "kind", "reset")
- /** Standard commands **/
+ /** Standard commands */
lazy val sparkStandardCommands: List[SparkILoop.this.LoopCommand] =
standardCommands.filter(cmd => !blockedCommands(cmd.name))
@@ -112,9 +109,9 @@ object SparkILoop {
val output = new JPrintWriter(new OutputStreamWriter(ostream), true)
val repl = new SparkILoop(input, output)
- if (sets.classpath.isDefault)
+ if (sets.classpath.isDefault) {
sets.classpath.value = sys.props("java.class.path")
-
+ }
repl process sets
}
}
diff --git a/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala b/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala
index dbfacba346..d3dafe9c42 100644
--- a/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala
+++ b/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala
@@ -267,7 +267,7 @@ class ReplSuite extends SparkFunSuite {
val output = runInterpreter("local",
"""
|import org.apache.spark.sql.functions._
- |import org.apache.spark.sql.Encoder
+ |import org.apache.spark.sql.{Encoder, Encoders}
|import org.apache.spark.sql.expressions.Aggregator
|import org.apache.spark.sql.TypedColumn
|val simpleSum = new Aggregator[Int, Int, Int] {
@@ -275,6 +275,8 @@ class ReplSuite extends SparkFunSuite {
| def reduce(b: Int, a: Int) = b + a // Add an element to the running total
| def merge(b1: Int, b2: Int) = b1 + b2 // Merge intermediate values.
| def finish(b: Int) = b // Return the final result.
+ | def bufferEncoder: Encoder[Int] = Encoders.scalaInt
+ | def outputEncoder: Encoder[Int] = Encoders.scalaInt
|}.toColumn
|
|val ds = Seq(1, 2, 3, 4).toDS()
@@ -321,31 +323,6 @@ class ReplSuite extends SparkFunSuite {
}
}
- test("Datasets agg type-inference") {
- val output = runInterpreter("local",
- """
- |import org.apache.spark.sql.functions._
- |import org.apache.spark.sql.Encoder
- |import org.apache.spark.sql.expressions.Aggregator
- |import org.apache.spark.sql.TypedColumn
- |/** An `Aggregator` that adds up any numeric type returned by the given function. */
- |class SumOf[I, N : Numeric](f: I => N) extends
- | org.apache.spark.sql.expressions.Aggregator[I, N, N] {
- | val numeric = implicitly[Numeric[N]]
- | override def zero: N = numeric.zero
- | override def reduce(b: N, a: I): N = numeric.plus(b, f(a))
- | override def merge(b1: N,b2: N): N = numeric.plus(b1, b2)
- | override def finish(reduction: N): N = reduction
- |}
- |
- |def sum[I, N : Numeric : Encoder](f: I => N): TypedColumn[I, N] = new SumOf(f).toColumn
- |val ds = Seq((1, 1, 2L), (1, 2, 3L), (1, 3, 4L), (2, 1, 5L)).toDS()
- |ds.groupByKey(_._1).agg(sum(_._2), sum(_._3)).collect()
- """.stripMargin)
- assertDoesNotContain("error:", output)
- assertDoesNotContain("Exception", output)
- }
-
test("collecting objects of class defined in repl") {
val output = runInterpreter("local[2]",
"""
@@ -396,4 +373,31 @@ class ReplSuite extends SparkFunSuite {
assertDoesNotContain("error:", output)
assertDoesNotContain("Exception", output)
}
+
+ test("should clone and clean line object in ClosureCleaner") {
+ val output = runInterpreterInPasteMode("local-cluster[1,4,4096]",
+ """
+ |import org.apache.spark.rdd.RDD
+ |
+ |val lines = sc.textFile("pom.xml")
+ |case class Data(s: String)
+ |val dataRDD = lines.map(line => Data(line.take(3)))
+ |dataRDD.cache.count
+ |val repartitioned = dataRDD.repartition(dataRDD.partitions.size)
+ |repartitioned.cache.count
+ |
+ |def getCacheSize(rdd: RDD[_]) = {
+ | sc.getRDDStorageInfo.filter(_.id == rdd.id).map(_.memSize).sum
+ |}
+ |val cacheSize1 = getCacheSize(dataRDD)
+ |val cacheSize2 = getCacheSize(repartitioned)
+ |
+ |// The cache size of dataRDD and the repartitioned one should be similar.
+ |val deviation = math.abs(cacheSize2 - cacheSize1).toDouble / cacheSize1
+ |assert(deviation < 0.2,
+ | s"deviation too large: $deviation, first size: $cacheSize1, second size: $cacheSize2")
+ """.stripMargin)
+ assertDoesNotContain("AssertionError", output)
+ assertDoesNotContain("Exception", output)
+ }
}
diff --git a/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala b/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala
index 928aaa5629..4a15d52b57 100644
--- a/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala
+++ b/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala
@@ -70,26 +70,24 @@ class ExecutorClassLoader(
}
override def findClass(name: String): Class[_] = {
- userClassPathFirst match {
- case true => findClassLocally(name).getOrElse(parentLoader.loadClass(name))
- case false => {
- try {
- parentLoader.loadClass(name)
- } catch {
- case e: ClassNotFoundException => {
- val classOption = findClassLocally(name)
- classOption match {
- case None =>
- // If this class has a cause, it will break the internal assumption of Janino
- // (the compiler used for Spark SQL code-gen).
- // See org.codehaus.janino.ClassLoaderIClassLoader's findIClass, you will see
- // its behavior will be changed if there is a cause and the compilation
- // of generated class will fail.
- throw new ClassNotFoundException(name)
- case Some(a) => a
- }
+ if (userClassPathFirst) {
+ findClassLocally(name).getOrElse(parentLoader.loadClass(name))
+ } else {
+ try {
+ parentLoader.loadClass(name)
+ } catch {
+ case e: ClassNotFoundException =>
+ val classOption = findClassLocally(name)
+ classOption match {
+ case None =>
+ // If this class has a cause, it will break the internal assumption of Janino
+ // (the compiler used for Spark SQL code-gen).
+ // See org.codehaus.janino.ClassLoaderIClassLoader's findIClass, you will see
+ // its behavior will be changed if there is a cause and the compilation
+ // of generated class will fail.
+ throw new ClassNotFoundException(name)
+ case Some(a) => a
}
- }
}
}
}
diff --git a/repl/src/test/resources/log4j.properties b/repl/src/test/resources/log4j.properties
index e2ee9c963a..7665bd5e7c 100644
--- a/repl/src/test/resources/log4j.properties
+++ b/repl/src/test/resources/log4j.properties
@@ -24,4 +24,4 @@ log4j.appender.file.layout=org.apache.log4j.PatternLayout
log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n
# Ignore messages below warning level from Jetty, because it's a bit verbose
-log4j.logger.org.spark-project.jetty=WARN
+log4j.logger.org.spark_project.jetty=WARN
diff --git a/scalastyle-config.xml b/scalastyle-config.xml
index 37d2ecf48e..a14e3e583f 100644
--- a/scalastyle-config.xml
+++ b/scalastyle-config.xml
@@ -116,7 +116,7 @@ This file is divided into 3 sections:
<check level="error" class="org.scalastyle.file.NewLineAtEofChecker" enabled="true"></check>
- <check level="error" class="org.scalastyle.scalariform.NonASCIICharacterChecker" enabled="true"></check>
+ <check customId="nonascii" level="error" class="org.scalastyle.scalariform.NonASCIICharacterChecker" enabled="true"></check>
<check level="error" class="org.scalastyle.scalariform.SpaceAfterCommentStartChecker" enabled="true"></check>
@@ -223,6 +223,16 @@ This file is divided into 3 sections:
]]></customMessage>
</check>
+ <check customId="NoScalaDoc" level="error" class="org.scalastyle.file.RegexChecker" enabled="true">
+ <parameters><parameter name="regex">(?m)^(\s*)/[*][*].*$(\r|)\n^\1 [*]</parameter></parameters>
+ <customMessage>Use Javadoc style indentation for multiline comments</customMessage>
+ </check>
+
+ <check customId="OmitBracesInCase" level="error" class="org.scalastyle.file.RegexChecker" enabled="true">
+ <parameters><parameter name="regex">case[^\n>]*=>\s*\{</parameter></parameters>
+ <customMessage>Omit braces in case clauses.</customMessage>
+ </check>
+
<!-- ================================================================================ -->
<!-- rules we'd like to enforce, but haven't cleaned up the codebase yet -->
<!-- ================================================================================ -->
diff --git a/sql/catalyst/pom.xml b/sql/catalyst/pom.xml
index 5d1d9edd25..1748fa2778 100644
--- a/sql/catalyst/pom.xml
+++ b/sql/catalyst/pom.xml
@@ -73,7 +73,7 @@
</dependency>
<dependency>
<groupId>org.antlr</groupId>
- <artifactId>antlr-runtime</artifactId>
+ <artifactId>antlr4-runtime</artifactId>
</dependency>
<dependency>
<groupId>commons-codec</groupId>
@@ -113,20 +113,17 @@
</plugin>
<plugin>
<groupId>org.antlr</groupId>
- <artifactId>antlr3-maven-plugin</artifactId>
+ <artifactId>antlr4-maven-plugin</artifactId>
<executions>
<execution>
<goals>
- <goal>antlr</goal>
+ <goal>antlr4</goal>
</goals>
</execution>
</executions>
<configuration>
- <sourceDirectory>../catalyst/src/main/antlr3</sourceDirectory>
- <includes>
- <include>**/SparkSqlLexer.g</include>
- <include>**/SparkSqlParser.g</include>
- </includes>
+ <visitor>true</visitor>
+ <sourceDirectory>../catalyst/src/main/antlr4</sourceDirectory>
</configuration>
</plugin>
</plugins>
diff --git a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/ExpressionParser.g b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/ExpressionParser.g
deleted file mode 100644
index 13a6a2d276..0000000000
--- a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/ExpressionParser.g
+++ /dev/null
@@ -1,400 +0,0 @@
-/**
- Licensed to the Apache Software Foundation (ASF) under one or more
- contributor license agreements. See the NOTICE file distributed with
- this work for additional information regarding copyright ownership.
- The ASF licenses this file to You under the Apache License, Version 2.0
- (the "License"); you may not use this file except in compliance with
- the License. You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
-
- This file is an adaptation of Hive's org/apache/hadoop/hive/ql/IdentifiersParser.g grammar.
-*/
-
-parser grammar ExpressionParser;
-
-options
-{
-output=AST;
-ASTLabelType=CommonTree;
-backtrack=false;
-k=3;
-}
-
-@members {
- @Override
- public Object recoverFromMismatchedSet(IntStream input,
- RecognitionException re, BitSet follow) throws RecognitionException {
- throw re;
- }
- @Override
- public void displayRecognitionError(String[] tokenNames,
- RecognitionException e) {
- gParent.displayRecognitionError(tokenNames, e);
- }
- protected boolean useSQL11ReservedKeywordsForIdentifier() {
- return gParent.useSQL11ReservedKeywordsForIdentifier();
- }
-}
-
-@rulecatch {
-catch (RecognitionException e) {
- throw e;
-}
-}
-
-// fun(par1, par2, par3)
-function
-@init { gParent.pushMsg("function specification", state); }
-@after { gParent.popMsg(state); }
- :
- functionName
- LPAREN
- (
- (STAR) => (star=STAR)
- | (dist=KW_DISTINCT)? (selectExpression (COMMA selectExpression)*)?
- )
- RPAREN (KW_OVER ws=window_specification)?
- -> {$star != null}? ^(TOK_FUNCTIONSTAR functionName $ws?)
- -> {$dist == null}? ^(TOK_FUNCTION functionName (selectExpression+)? $ws?)
- -> ^(TOK_FUNCTIONDI functionName (selectExpression+)? $ws?)
- ;
-
-functionName
-@init { gParent.pushMsg("function name", state); }
-@after { gParent.popMsg(state); }
- : // Keyword IF is also a function name
- (KW_IF | KW_ARRAY | KW_MAP | KW_STRUCT | KW_UNIONTYPE) => (KW_IF | KW_ARRAY | KW_MAP | KW_STRUCT | KW_UNIONTYPE)
- |
- (functionIdentifier) => functionIdentifier
- |
- {!useSQL11ReservedKeywordsForIdentifier()}? sql11ReservedKeywordsUsedAsCastFunctionName -> Identifier[$sql11ReservedKeywordsUsedAsCastFunctionName.text]
- ;
-
-castExpression
-@init { gParent.pushMsg("cast expression", state); }
-@after { gParent.popMsg(state); }
- :
- KW_CAST
- LPAREN
- expression
- KW_AS
- primitiveType
- RPAREN -> ^(TOK_FUNCTION primitiveType expression)
- ;
-
-caseExpression
-@init { gParent.pushMsg("case expression", state); }
-@after { gParent.popMsg(state); }
- :
- KW_CASE expression
- (KW_WHEN expression KW_THEN expression)+
- (KW_ELSE expression)?
- KW_END -> ^(TOK_FUNCTION KW_CASE expression*)
- ;
-
-whenExpression
-@init { gParent.pushMsg("case expression", state); }
-@after { gParent.popMsg(state); }
- :
- KW_CASE
- ( KW_WHEN expression KW_THEN expression)+
- (KW_ELSE expression)?
- KW_END -> ^(TOK_FUNCTION KW_WHEN expression*)
- ;
-
-constant
-@init { gParent.pushMsg("constant", state); }
-@after { gParent.popMsg(state); }
- :
- Number
- | dateLiteral
- | timestampLiteral
- | intervalLiteral
- | StringLiteral
- | stringLiteralSequence
- | BigintLiteral
- | SmallintLiteral
- | TinyintLiteral
- | DoubleLiteral
- | booleanValue
- ;
-
-stringLiteralSequence
- :
- StringLiteral StringLiteral+ -> ^(TOK_STRINGLITERALSEQUENCE StringLiteral StringLiteral+)
- ;
-
-dateLiteral
- :
- KW_DATE StringLiteral ->
- {
- // Create DateLiteral token, but with the text of the string value
- // This makes the dateLiteral more consistent with the other type literals.
- adaptor.create(TOK_DATELITERAL, $StringLiteral.text)
- }
- |
- KW_CURRENT_DATE -> ^(TOK_FUNCTION KW_CURRENT_DATE)
- ;
-
-timestampLiteral
- :
- KW_TIMESTAMP StringLiteral ->
- {
- adaptor.create(TOK_TIMESTAMPLITERAL, $StringLiteral.text)
- }
- |
- KW_CURRENT_TIMESTAMP -> ^(TOK_FUNCTION KW_CURRENT_TIMESTAMP)
- ;
-
-intervalLiteral
- :
- (KW_INTERVAL intervalConstant KW_YEAR KW_TO KW_MONTH) => KW_INTERVAL intervalConstant KW_YEAR KW_TO KW_MONTH
- -> ^(TOK_INTERVAL_YEAR_MONTH_LITERAL intervalConstant)
- | (KW_INTERVAL intervalConstant KW_DAY KW_TO KW_SECOND) => KW_INTERVAL intervalConstant KW_DAY KW_TO KW_SECOND
- -> ^(TOK_INTERVAL_DAY_TIME_LITERAL intervalConstant)
- | KW_INTERVAL
- ((intervalConstant KW_YEAR)=> year=intervalConstant KW_YEAR)?
- ((intervalConstant KW_MONTH)=> month=intervalConstant KW_MONTH)?
- ((intervalConstant KW_WEEK)=> week=intervalConstant KW_WEEK)?
- ((intervalConstant KW_DAY)=> day=intervalConstant KW_DAY)?
- ((intervalConstant KW_HOUR)=> hour=intervalConstant KW_HOUR)?
- ((intervalConstant KW_MINUTE)=> minute=intervalConstant KW_MINUTE)?
- ((intervalConstant KW_SECOND)=> second=intervalConstant KW_SECOND)?
- ((intervalConstant KW_MILLISECOND)=> millisecond=intervalConstant KW_MILLISECOND)?
- ((intervalConstant KW_MICROSECOND)=> microsecond=intervalConstant KW_MICROSECOND)?
- -> ^(TOK_INTERVAL
- ^(TOK_INTERVAL_YEAR_LITERAL $year?)
- ^(TOK_INTERVAL_MONTH_LITERAL $month?)
- ^(TOK_INTERVAL_WEEK_LITERAL $week?)
- ^(TOK_INTERVAL_DAY_LITERAL $day?)
- ^(TOK_INTERVAL_HOUR_LITERAL $hour?)
- ^(TOK_INTERVAL_MINUTE_LITERAL $minute?)
- ^(TOK_INTERVAL_SECOND_LITERAL $second?)
- ^(TOK_INTERVAL_MILLISECOND_LITERAL $millisecond?)
- ^(TOK_INTERVAL_MICROSECOND_LITERAL $microsecond?))
- ;
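For reference, these are the kinds of literals the date, timestamp, and interval rules above describe. A small spark-shell sketch (Spark 2.x or later, `spark` predefined; illustrative only, not specific to this parser):

// DATE and TIMESTAMP literals carry the text of the string value.
spark.sql("SELECT DATE '2016-04-14' AS d, CURRENT_DATE AS today").show()

// Interval literals combine a numeric constant with a unit keyword.
spark.sql("SELECT TIMESTAMP '2016-04-14 16:43:28' + INTERVAL 12 HOURS AS later").show()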
-
-intervalConstant
- :
- sign=(MINUS|PLUS)? value=Number -> {
- adaptor.create(Number, ($sign != null ? $sign.getText() : "") + $value.getText())
- }
- | StringLiteral
- ;
-
-expression
-@init { gParent.pushMsg("expression specification", state); }
-@after { gParent.popMsg(state); }
- :
- precedenceOrExpression
- ;
-
-atomExpression
- :
- (KW_NULL) => KW_NULL -> TOK_NULL
- | (constant) => constant
- | castExpression
- | caseExpression
- | whenExpression
- | (functionName LPAREN) => function
- | tableOrColumn
- | (LPAREN KW_SELECT) => subQueryExpression
- -> ^(TOK_SUBQUERY_EXPR ^(TOK_SUBQUERY_OP) subQueryExpression)
- | LPAREN! expression RPAREN!
- ;
-
-
-precedenceFieldExpression
- :
- atomExpression ((LSQUARE^ expression RSQUARE!) | (DOT^ identifier))*
- ;
-
-precedenceUnaryOperator
- :
- PLUS | MINUS | TILDE
- ;
-
-nullCondition
- :
- KW_NULL -> ^(TOK_ISNULL)
- | KW_NOT KW_NULL -> ^(TOK_ISNOTNULL)
- ;
-
-precedenceUnaryPrefixExpression
- :
- (precedenceUnaryOperator+)=> precedenceUnaryOperator^ precedenceUnaryPrefixExpression
- | precedenceFieldExpression
- ;
-
-precedenceUnarySuffixExpression
- :
- (
- (LPAREN precedenceUnaryPrefixExpression RPAREN) => LPAREN precedenceUnaryPrefixExpression (a=KW_IS nullCondition)? RPAREN
- |
- precedenceUnaryPrefixExpression (a=KW_IS nullCondition)?
- )
- -> {$a != null}? ^(TOK_FUNCTION nullCondition precedenceUnaryPrefixExpression)
- -> precedenceUnaryPrefixExpression
- ;
-
-
-precedenceBitwiseXorOperator
- :
- BITWISEXOR
- ;
-
-precedenceBitwiseXorExpression
- :
- precedenceUnarySuffixExpression (precedenceBitwiseXorOperator^ precedenceUnarySuffixExpression)*
- ;
-
-
-precedenceStarOperator
- :
- STAR | DIVIDE | MOD | DIV
- ;
-
-precedenceStarExpression
- :
- precedenceBitwiseXorExpression (precedenceStarOperator^ precedenceBitwiseXorExpression)*
- ;
-
-
-precedencePlusOperator
- :
- PLUS | MINUS
- ;
-
-precedencePlusExpression
- :
- precedenceStarExpression (precedencePlusOperator^ precedenceStarExpression)*
- ;
-
-
-precedenceAmpersandOperator
- :
- AMPERSAND
- ;
-
-precedenceAmpersandExpression
- :
- precedencePlusExpression (precedenceAmpersandOperator^ precedencePlusExpression)*
- ;
-
-
-precedenceBitwiseOrOperator
- :
- BITWISEOR
- ;
-
-precedenceBitwiseOrExpression
- :
- precedenceAmpersandExpression (precedenceBitwiseOrOperator^ precedenceAmpersandExpression)*
- ;
-
-
-// Equal operators supporting NOT prefix
-precedenceEqualNegatableOperator
- :
- KW_LIKE | KW_RLIKE | KW_REGEXP
- ;
-
-precedenceEqualOperator
- :
- precedenceEqualNegatableOperator | EQUAL | EQUAL_NS | NOTEQUAL | LESSTHANOREQUALTO | LESSTHAN | GREATERTHANOREQUALTO | GREATERTHAN
- ;
-
-subQueryExpression
- :
- LPAREN! selectStatement[true] RPAREN!
- ;
-
-precedenceEqualExpression
- :
- (LPAREN precedenceBitwiseOrExpression COMMA) => precedenceEqualExpressionMutiple
- |
- precedenceEqualExpressionSingle
- ;
-
-precedenceEqualExpressionSingle
- :
- (left=precedenceBitwiseOrExpression -> $left)
- (
- (KW_NOT precedenceEqualNegatableOperator notExpr=precedenceBitwiseOrExpression)
- -> ^(KW_NOT ^(precedenceEqualNegatableOperator $precedenceEqualExpressionSingle $notExpr))
- | (precedenceEqualOperator equalExpr=precedenceBitwiseOrExpression)
- -> ^(precedenceEqualOperator $precedenceEqualExpressionSingle $equalExpr)
- | (KW_NOT KW_IN LPAREN KW_SELECT)=> (KW_NOT KW_IN subQueryExpression)
- -> ^(KW_NOT ^(TOK_SUBQUERY_EXPR ^(TOK_SUBQUERY_OP KW_IN) subQueryExpression $precedenceEqualExpressionSingle))
- | (KW_NOT KW_IN expressions)
- -> ^(KW_NOT ^(TOK_FUNCTION KW_IN $precedenceEqualExpressionSingle expressions))
- | (KW_IN LPAREN KW_SELECT)=> (KW_IN subQueryExpression)
- -> ^(TOK_SUBQUERY_EXPR ^(TOK_SUBQUERY_OP KW_IN) subQueryExpression $precedenceEqualExpressionSingle)
- | (KW_IN expressions)
- -> ^(TOK_FUNCTION KW_IN $precedenceEqualExpressionSingle expressions)
- | ( KW_NOT KW_BETWEEN (min=precedenceBitwiseOrExpression) KW_AND (max=precedenceBitwiseOrExpression) )
- -> ^(TOK_FUNCTION Identifier["between"] KW_TRUE $left $min $max)
- | ( KW_BETWEEN (min=precedenceBitwiseOrExpression) KW_AND (max=precedenceBitwiseOrExpression) )
- -> ^(TOK_FUNCTION Identifier["between"] KW_FALSE $left $min $max)
- )*
- | (KW_EXISTS LPAREN KW_SELECT)=> (KW_EXISTS subQueryExpression) -> ^(TOK_SUBQUERY_EXPR ^(TOK_SUBQUERY_OP KW_EXISTS) subQueryExpression)
- ;
-
-expressions
- :
- LPAREN expression (COMMA expression)* RPAREN -> expression+
- ;
-
-// We transform "(col0, col1) IN ((v00, v01), (v10, v11))" into "struct(col0, col1) IN (struct(v00, v01), struct(v10, v11))"; see the sketch after the two rules below.
-precedenceEqualExpressionMutiple
- :
- (LPAREN precedenceBitwiseOrExpression (COMMA precedenceBitwiseOrExpression)+ RPAREN -> ^(TOK_FUNCTION Identifier["struct"] precedenceBitwiseOrExpression+))
- ( (KW_IN LPAREN expressionsToStruct (COMMA expressionsToStruct)+ RPAREN)
- -> ^(TOK_FUNCTION KW_IN $precedenceEqualExpressionMutiple expressionsToStruct+)
- | (KW_NOT KW_IN LPAREN expressionsToStruct (COMMA expressionsToStruct)+ RPAREN)
- -> ^(KW_NOT ^(TOK_FUNCTION KW_IN $precedenceEqualExpressionMutiple expressionsToStruct+)))
- ;
-
-expressionsToStruct
- :
- LPAREN expression (COMMA expression)* RPAREN -> ^(TOK_FUNCTION Identifier["struct"] expression+)
- ;
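A minimal spark-shell sketch of the struct rewrite described in the comment above (`spark` predefined, table made up; the behavior shown is that of Spark SQL in general, not a statement about this exact parser):

import spark.implicits._

Seq((1, "a"), (2, "b"), (3, "c")).toDF("col0", "col1").createOrReplaceTempView("t")

// (col0, col1) IN ((1, 'a'), (3, 'c')) behaves like
// struct(col0, col1) IN (struct(1, 'a'), struct(3, 'c')).
spark.sql("SELECT * FROM t WHERE (col0, col1) IN ((1, 'a'), (3, 'c'))").show()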
-
-precedenceNotOperator
- :
- KW_NOT
- ;
-
-precedenceNotExpression
- :
- (precedenceNotOperator^)* precedenceEqualExpression
- ;
-
-
-precedenceAndOperator
- :
- KW_AND
- ;
-
-precedenceAndExpression
- :
- precedenceNotExpression (precedenceAndOperator^ precedenceNotExpression)*
- ;
-
-
-precedenceOrOperator
- :
- KW_OR
- ;
-
-precedenceOrExpression
- :
- precedenceAndExpression (precedenceOrOperator^ precedenceAndExpression)*
- ;
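The precedence* cascade above encodes binding strength, from tightest to loosest: unary +/-/~, then ^, then * / % DIV, then binary + and -, then &, then |, then comparisons/LIKE/IN/BETWEEN, then NOT, AND, and finally OR. A quick spark-shell check of the resulting grouping (`spark` predefined; ordinary Spark SQL, not specific to this parser):

// 1 + 2 * 3 groups as 1 + (2 * 3); AND binds tighter than OR.
spark.sql("SELECT 1 + 2 * 3 AS a, (1 + 2) * 3 AS b, true OR false AND false AS c").show()
// Expected: a = 7, b = 9, c = true.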
diff --git a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/FromClauseParser.g b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/FromClauseParser.g
deleted file mode 100644
index 1bf461c912..0000000000
--- a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/FromClauseParser.g
+++ /dev/null
@@ -1,341 +0,0 @@
-/**
- Licensed to the Apache Software Foundation (ASF) under one or more
- contributor license agreements. See the NOTICE file distributed with
- this work for additional information regarding copyright ownership.
- The ASF licenses this file to You under the Apache License, Version 2.0
- (the "License"); you may not use this file except in compliance with
- the License. You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
-
- This file is an adaptation of Hive's org/apache/hadoop/hive/ql/FromClauseParser.g grammar.
-*/
-parser grammar FromClauseParser;
-
-options
-{
-output=AST;
-ASTLabelType=CommonTree;
-backtrack=false;
-k=3;
-}
-
-@members {
- @Override
- public Object recoverFromMismatchedSet(IntStream input,
- RecognitionException re, BitSet follow) throws RecognitionException {
- throw re;
- }
- @Override
- public void displayRecognitionError(String[] tokenNames,
- RecognitionException e) {
- gParent.displayRecognitionError(tokenNames, e);
- }
- protected boolean useSQL11ReservedKeywordsForIdentifier() {
- return gParent.useSQL11ReservedKeywordsForIdentifier();
- }
-}
-
-@rulecatch {
-catch (RecognitionException e) {
- throw e;
-}
-}
-
-//-----------------------------------------------------------------------------------
-
-tableAllColumns
- : STAR
- -> ^(TOK_ALLCOLREF)
- | tableName DOT STAR
- -> ^(TOK_ALLCOLREF tableName)
- ;
-
-// (table|column)
-tableOrColumn
-@init { gParent.pushMsg("table or column identifier", state); }
-@after { gParent.popMsg(state); }
- :
- identifier -> ^(TOK_TABLE_OR_COL identifier)
- ;
-
-expressionList
-@init { gParent.pushMsg("expression list", state); }
-@after { gParent.popMsg(state); }
- :
- expression (COMMA expression)* -> ^(TOK_EXPLIST expression+)
- ;
-
-aliasList
-@init { gParent.pushMsg("alias list", state); }
-@after { gParent.popMsg(state); }
- :
- identifier (COMMA identifier)* -> ^(TOK_ALIASLIST identifier+)
- ;
-
-//----------------------- Rules for parsing fromClause ------------------------------
-// from [col1, col2, col3] table1, [col4, col5] table2
-fromClause
-@init { gParent.pushMsg("from clause", state); }
-@after { gParent.popMsg(state); }
- :
- KW_FROM joinSource -> ^(TOK_FROM joinSource)
- ;
-
-joinSource
-@init { gParent.pushMsg("join source", state); }
-@after { gParent.popMsg(state); }
- : fromSource ( joinToken^ fromSource ( joinCond {$joinToken.start.getType() != COMMA}? )? )*
- | uniqueJoinToken^ uniqueJoinSource (COMMA! uniqueJoinSource)+
- ;
-
-joinCond
-@init { gParent.pushMsg("join expression list", state); }
-@after { gParent.popMsg(state); }
- : KW_ON! expression
- | KW_USING LPAREN columnNameList RPAREN -> ^(TOK_USING columnNameList)
- ;
-
-uniqueJoinSource
-@init { gParent.pushMsg("unique join source", state); }
-@after { gParent.popMsg(state); }
- : KW_PRESERVE? fromSource uniqueJoinExpr
- ;
-
-uniqueJoinExpr
-@init { gParent.pushMsg("unique join expression list", state); }
-@after { gParent.popMsg(state); }
- : LPAREN e1+=expression (COMMA e1+=expression)* RPAREN
- -> ^(TOK_EXPLIST $e1*)
- ;
-
-uniqueJoinToken
-@init { gParent.pushMsg("unique join", state); }
-@after { gParent.popMsg(state); }
- : KW_UNIQUEJOIN -> TOK_UNIQUEJOIN;
-
-joinToken
-@init { gParent.pushMsg("join type specifier", state); }
-@after { gParent.popMsg(state); }
- :
- KW_JOIN -> TOK_JOIN
- | KW_INNER KW_JOIN -> TOK_JOIN
- | KW_NATURAL KW_JOIN -> TOK_NATURALJOIN
- | KW_NATURAL KW_INNER KW_JOIN -> TOK_NATURALJOIN
- | COMMA -> TOK_JOIN
- | KW_CROSS KW_JOIN -> TOK_CROSSJOIN
- | KW_LEFT (KW_OUTER)? KW_JOIN -> TOK_LEFTOUTERJOIN
- | KW_RIGHT (KW_OUTER)? KW_JOIN -> TOK_RIGHTOUTERJOIN
- | KW_FULL (KW_OUTER)? KW_JOIN -> TOK_FULLOUTERJOIN
- | KW_NATURAL KW_LEFT (KW_OUTER)? KW_JOIN -> TOK_NATURALLEFTOUTERJOIN
- | KW_NATURAL KW_RIGHT (KW_OUTER)? KW_JOIN -> TOK_NATURALRIGHTOUTERJOIN
- | KW_NATURAL KW_FULL (KW_OUTER)? KW_JOIN -> TOK_NATURALFULLOUTERJOIN
- | KW_LEFT KW_SEMI KW_JOIN -> TOK_LEFTSEMIJOIN
- | KW_ANTI KW_JOIN -> TOK_ANTIJOIN
- ;
-
-lateralView
-@init {gParent.pushMsg("lateral view", state); }
-@after {gParent.popMsg(state); }
- :
- (KW_LATERAL KW_VIEW KW_OUTER) => KW_LATERAL KW_VIEW KW_OUTER function tableAlias (KW_AS identifier ((COMMA)=> COMMA identifier)*)?
- -> ^(TOK_LATERAL_VIEW_OUTER ^(TOK_SELECT ^(TOK_SELEXPR function identifier* tableAlias)))
- |
- KW_LATERAL KW_VIEW function tableAlias (KW_AS identifier ((COMMA)=> COMMA identifier)*)?
- -> ^(TOK_LATERAL_VIEW ^(TOK_SELECT ^(TOK_SELEXPR function identifier* tableAlias)))
- ;
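A short spark-shell sketch of the LATERAL VIEW form parsed above (`spark` predefined, data made up; illustrative only):

import spark.implicits._

Seq((1, Seq("a", "b")), (2, Seq("c"))).toDF("id", "items").createOrReplaceTempView("t")

// Each array element becomes its own output row.
spark.sql("SELECT id, item FROM t LATERAL VIEW explode(items) x AS item").show()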
-
-tableAlias
-@init {gParent.pushMsg("table alias", state); }
-@after {gParent.popMsg(state); }
- :
- identifier -> ^(TOK_TABALIAS identifier)
- ;
-
-fromSource
-@init { gParent.pushMsg("from source", state); }
-@after { gParent.popMsg(state); }
- :
- (LPAREN KW_VALUES) => fromSource0
- | fromSource0
- | (LPAREN joinSource) => LPAREN joinSource RPAREN -> joinSource
- ;
-
-
-fromSource0
-@init { gParent.pushMsg("from source 0", state); }
-@after { gParent.popMsg(state); }
- :
- ((Identifier LPAREN)=> partitionedTableFunction | tableSource | subQuerySource | virtualTableSource) (lateralView^)*
- ;
-
-tableBucketSample
-@init { gParent.pushMsg("table bucket sample specification", state); }
-@after { gParent.popMsg(state); }
- :
- KW_TABLESAMPLE LPAREN KW_BUCKET (numerator=Number) KW_OUT KW_OF (denominator=Number) (KW_ON expr+=expression (COMMA expr+=expression)*)? RPAREN -> ^(TOK_TABLEBUCKETSAMPLE $numerator $denominator $expr*)
- ;
-
-splitSample
-@init { gParent.pushMsg("table split sample specification", state); }
-@after { gParent.popMsg(state); }
- :
- KW_TABLESAMPLE LPAREN (numerator=Number) (percent=KW_PERCENT|KW_ROWS) RPAREN
- -> {percent != null}? ^(TOK_TABLESPLITSAMPLE TOK_PERCENT $numerator)
- -> ^(TOK_TABLESPLITSAMPLE TOK_ROWCOUNT $numerator)
- |
- KW_TABLESAMPLE LPAREN (numerator=ByteLengthLiteral) RPAREN
- -> ^(TOK_TABLESPLITSAMPLE TOK_LENGTH $numerator)
- ;
-
-tableSample
-@init { gParent.pushMsg("table sample specification", state); }
-@after { gParent.popMsg(state); }
- :
- tableBucketSample |
- splitSample
- ;
-
-tableSource
-@init { gParent.pushMsg("table source", state); }
-@after { gParent.popMsg(state); }
- : tabname=tableName
- ((tableProperties) => props=tableProperties)?
- ((tableSample) => ts=tableSample)?
- ((KW_AS) => (KW_AS alias=Identifier)
- |
- (Identifier) => (alias=Identifier))?
- -> ^(TOK_TABREF $tabname $props? $ts? $alias?)
- ;
-
-tableName
-@init { gParent.pushMsg("table name", state); }
-@after { gParent.popMsg(state); }
- :
- id1=identifier (DOT id2=identifier)?
- -> ^(TOK_TABNAME $id1 $id2?)
- ;
-
-viewName
-@init { gParent.pushMsg("view name", state); }
-@after { gParent.popMsg(state); }
- :
- (db=identifier DOT)? view=identifier
- -> ^(TOK_TABNAME $db? $view)
- ;
-
-subQuerySource
-@init { gParent.pushMsg("subquery source", state); }
-@after { gParent.popMsg(state); }
- :
- LPAREN queryStatementExpression[false] RPAREN KW_AS? identifier -> ^(TOK_SUBQUERY queryStatementExpression identifier)
- ;
-
-//---------------------- Rules for parsing PTF clauses -----------------------------
-partitioningSpec
-@init { gParent.pushMsg("partitioningSpec clause", state); }
-@after { gParent.popMsg(state); }
- :
- partitionByClause orderByClause? -> ^(TOK_PARTITIONINGSPEC partitionByClause orderByClause?) |
- orderByClause -> ^(TOK_PARTITIONINGSPEC orderByClause) |
- distributeByClause sortByClause? -> ^(TOK_PARTITIONINGSPEC distributeByClause sortByClause?) |
- sortByClause -> ^(TOK_PARTITIONINGSPEC sortByClause) |
- clusterByClause -> ^(TOK_PARTITIONINGSPEC clusterByClause)
- ;
-
-partitionTableFunctionSource
-@init { gParent.pushMsg("partitionTableFunctionSource clause", state); }
-@after { gParent.popMsg(state); }
- :
- subQuerySource |
- tableSource |
- partitionedTableFunction
- ;
-
-partitionedTableFunction
-@init { gParent.pushMsg("ptf clause", state); }
-@after { gParent.popMsg(state); }
- :
- name=Identifier LPAREN KW_ON
- ((partitionTableFunctionSource) => (ptfsrc=partitionTableFunctionSource spec=partitioningSpec?))
- ((Identifier LPAREN expression RPAREN ) => Identifier LPAREN expression RPAREN ( COMMA Identifier LPAREN expression RPAREN)*)?
- ((RPAREN) => (RPAREN)) ((Identifier) => alias=Identifier)?
- -> ^(TOK_PTBLFUNCTION $name $alias? $ptfsrc $spec? expression*)
- ;
-
-//----------------------- Rules for parsing whereClause -----------------------------
-// where a=b and ...
-whereClause
-@init { gParent.pushMsg("where clause", state); }
-@after { gParent.popMsg(state); }
- :
- KW_WHERE searchCondition -> ^(TOK_WHERE searchCondition)
- ;
-
-searchCondition
-@init { gParent.pushMsg("search condition", state); }
-@after { gParent.popMsg(state); }
- :
- expression
- ;
-
-//-----------------------------------------------------------------------------------
-
-//-------- Row Constructor ----------------------------------------------------------
-//in support of SELECT * FROM (VALUES(1,2,3),(4,5,6),...) as FOO(a,b,c) and
-// INSERT INTO <table> (col1,col2,...) VALUES(...),(...),...
-// INSERT INTO <table> (col1,col2,...) SELECT * FROM (VALUES(1,2,3),(4,5,6),...) as Foo(a,b,c)
-valueRowConstructor
-@init { gParent.pushMsg("value row constructor", state); }
-@after { gParent.popMsg(state); }
- :
- LPAREN precedenceUnaryPrefixExpression (COMMA precedenceUnaryPrefixExpression)* RPAREN -> ^(TOK_VALUE_ROW precedenceUnaryPrefixExpression+)
- ;
-
-valuesTableConstructor
-@init { gParent.pushMsg("values table constructor", state); }
-@after { gParent.popMsg(state); }
- :
- valueRowConstructor (COMMA valueRowConstructor)* -> ^(TOK_VALUES_TABLE valueRowConstructor+)
- ;
-
-/*
-VALUES(1),(2) means 2 rows, 1 column each.
-VALUES(1,2),(3,4) means 2 rows, 2 columns each.
-VALUES(1,2,3) means 1 row, 3 columns
-*/
-valuesClause
-@init { gParent.pushMsg("values clause", state); }
-@after { gParent.popMsg(state); }
- :
- KW_VALUES valuesTableConstructor -> valuesTableConstructor
- ;
-
-/*
-This represents a clause like this:
-(VALUES(1,2),(2,3)) as VirtTable(col1,col2)
-*/
-virtualTableSource
-@init { gParent.pushMsg("virtual table source", state); }
-@after { gParent.popMsg(state); }
- :
- LPAREN valuesClause RPAREN tableNameColList -> ^(TOK_VIRTUAL_TABLE tableNameColList valuesClause)
- ;
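For illustration, inline VALUES tables can be exercised from spark-shell; the exact surface syntax has varied across Spark versions, so treat this as a sketch rather than a statement about this grammar (`spark` predefined):

// Two rows of two columns each; VALUES (1, 2, 3) alone would be one row of three columns.
// This grammar's virtualTableSource additionally wraps the VALUES clause in parentheses.
spark.sql("SELECT * FROM VALUES (1, 2), (3, 4) AS t(a, b)").show()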
-/*
-e.g. as VirtTable(col1,col2)
-Note that we only want literals as column names
-*/
-tableNameColList
-@init { gParent.pushMsg("table name column list", state); }
-@after { gParent.popMsg(state); }
- :
- KW_AS? identifier LPAREN identifier (COMMA identifier)* RPAREN -> ^(TOK_VIRTUAL_TABREF ^(TOK_TABNAME identifier) ^(TOK_COL_NAME identifier+))
- ;
-
-//-----------------------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/IdentifiersParser.g b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/IdentifiersParser.g
deleted file mode 100644
index 916eb6a7ac..0000000000
--- a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/IdentifiersParser.g
+++ /dev/null
@@ -1,184 +0,0 @@
-/**
- Licensed to the Apache Software Foundation (ASF) under one or more
- contributor license agreements. See the NOTICE file distributed with
- this work for additional information regarding copyright ownership.
- The ASF licenses this file to You under the Apache License, Version 2.0
- (the "License"); you may not use this file except in compliance with
- the License. You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
-
- This file is an adaptation of Hive's org/apache/hadoop/hive/ql/IdentifiersParser.g grammar.
-*/
-parser grammar IdentifiersParser;
-
-options
-{
-output=AST;
-ASTLabelType=CommonTree;
-backtrack=false;
-k=3;
-}
-
-@members {
- @Override
- public Object recoverFromMismatchedSet(IntStream input,
- RecognitionException re, BitSet follow) throws RecognitionException {
- throw re;
- }
- @Override
- public void displayRecognitionError(String[] tokenNames,
- RecognitionException e) {
- gParent.displayRecognitionError(tokenNames, e);
- }
- protected boolean useSQL11ReservedKeywordsForIdentifier() {
- return gParent.useSQL11ReservedKeywordsForIdentifier();
- }
-}
-
-@rulecatch {
-catch (RecognitionException e) {
- throw e;
-}
-}
-
-//-----------------------------------------------------------------------------------
-
-// group by a,b
-groupByClause
-@init { gParent.pushMsg("group by clause", state); }
-@after { gParent.popMsg(state); }
- :
- KW_GROUP KW_BY
- expression
- ( COMMA expression)*
- ((rollup=KW_WITH KW_ROLLUP) | (cube=KW_WITH KW_CUBE)) ?
- (sets=KW_GROUPING KW_SETS
- LPAREN groupingSetExpression ( COMMA groupingSetExpression)* RPAREN ) ?
- -> {rollup != null}? ^(TOK_ROLLUP_GROUPBY expression+)
- -> {cube != null}? ^(TOK_CUBE_GROUPBY expression+)
- -> {sets != null}? ^(TOK_GROUPING_SETS expression+ groupingSetExpression+)
- -> ^(TOK_GROUPBY expression+)
- ;
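A spark-shell sketch of the ROLLUP, CUBE, and GROUPING SETS variants accepted above (`spark` predefined, data made up; standard Spark SQL behavior):

import spark.implicits._

Seq(("eng", "us", 10), ("eng", "eu", 20), ("ops", "us", 5))
  .toDF("dept", "region", "cost")
  .createOrReplaceTempView("costs")

// WITH ROLLUP adds subtotal and grand-total rows.
spark.sql("SELECT dept, region, sum(cost) AS total FROM costs GROUP BY dept, region WITH ROLLUP").show()

// GROUPING SETS enumerates the groupings explicitly; () is the grand total.
spark.sql("SELECT dept, region, sum(cost) AS total FROM costs GROUP BY dept, region GROUPING SETS ((dept), (region), ())").show()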
-
-groupingSetExpression
-@init {gParent.pushMsg("grouping set expression", state); }
-@after {gParent.popMsg(state); }
- :
- (LPAREN) => groupingSetExpressionMultiple
- |
- groupingExpressionSingle
- ;
-
-groupingSetExpressionMultiple
-@init {gParent.pushMsg("grouping set part expression", state); }
-@after {gParent.popMsg(state); }
- :
- LPAREN
- expression? (COMMA expression)*
- RPAREN
- -> ^(TOK_GROUPING_SETS_EXPRESSION expression*)
- ;
-
-groupingExpressionSingle
-@init { gParent.pushMsg("groupingExpression expression", state); }
-@after { gParent.popMsg(state); }
- :
- expression -> ^(TOK_GROUPING_SETS_EXPRESSION expression)
- ;
-
-havingClause
-@init { gParent.pushMsg("having clause", state); }
-@after { gParent.popMsg(state); }
- :
- KW_HAVING havingCondition -> ^(TOK_HAVING havingCondition)
- ;
-
-havingCondition
-@init { gParent.pushMsg("having condition", state); }
-@after { gParent.popMsg(state); }
- :
- expression
- ;
-
-expressionsInParenthese
- :
- LPAREN expression (COMMA expression)* RPAREN -> expression+
- ;
-
-expressionsNotInParenthese
- :
- expression (COMMA expression)* -> expression+
- ;
-
-columnRefOrderInParenthese
- :
- LPAREN columnRefOrder (COMMA columnRefOrder)* RPAREN -> columnRefOrder+
- ;
-
-columnRefOrderNotInParenthese
- :
- columnRefOrder (COMMA columnRefOrder)* -> columnRefOrder+
- ;
-
-// order by a,b
-orderByClause
-@init { gParent.pushMsg("order by clause", state); }
-@after { gParent.popMsg(state); }
- :
- KW_ORDER KW_BY columnRefOrder ( COMMA columnRefOrder)* -> ^(TOK_ORDERBY columnRefOrder+)
- ;
-
-clusterByClause
-@init { gParent.pushMsg("cluster by clause", state); }
-@after { gParent.popMsg(state); }
- :
- KW_CLUSTER KW_BY
- (
- (LPAREN) => expressionsInParenthese -> ^(TOK_CLUSTERBY expressionsInParenthese)
- |
- expressionsNotInParenthese -> ^(TOK_CLUSTERBY expressionsNotInParenthese)
- )
- ;
-
-partitionByClause
-@init { gParent.pushMsg("partition by clause", state); }
-@after { gParent.popMsg(state); }
- :
- KW_PARTITION KW_BY
- (
- (LPAREN) => expressionsInParenthese -> ^(TOK_DISTRIBUTEBY expressionsInParenthese)
- |
- expressionsNotInParenthese -> ^(TOK_DISTRIBUTEBY expressionsNotInParenthese)
- )
- ;
-
-distributeByClause
-@init { gParent.pushMsg("distribute by clause", state); }
-@after { gParent.popMsg(state); }
- :
- KW_DISTRIBUTE KW_BY
- (
- (LPAREN) => expressionsInParenthese -> ^(TOK_DISTRIBUTEBY expressionsInParenthese)
- |
- expressionsNotInParenthese -> ^(TOK_DISTRIBUTEBY expressionsNotInParenthese)
- )
- ;
-
-sortByClause
-@init { gParent.pushMsg("sort by clause", state); }
-@after { gParent.popMsg(state); }
- :
- KW_SORT KW_BY
- (
- (LPAREN) => columnRefOrderInParenthese -> ^(TOK_SORTBY columnRefOrderInParenthese)
- |
- columnRefOrderNotInParenthese -> ^(TOK_SORTBY columnRefOrderNotInParenthese)
- )
- ;
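The clauses above differ only in how rows are shuffled and ordered; note that partitionByClause reuses TOK_DISTRIBUTEBY. A spark-shell sketch (`spark` predefined, data made up; illustrative only):

import spark.implicits._

Seq((3, "eng"), (1, "ops"), (2, "eng")).toDF("id", "dept").createOrReplaceTempView("people")

// DISTRIBUTE BY controls partitioning; SORT BY orders rows within each partition.
spark.sql("SELECT id, dept FROM people DISTRIBUTE BY dept SORT BY id").show()

// CLUSTER BY x is shorthand for DISTRIBUTE BY x SORT BY x.
spark.sql("SELECT id, dept FROM people CLUSTER BY dept").show()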
diff --git a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/KeywordParser.g b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/KeywordParser.g
deleted file mode 100644
index 12cd5f54a0..0000000000
--- a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/KeywordParser.g
+++ /dev/null
@@ -1,244 +0,0 @@
-/**
- Licensed to the Apache Software Foundation (ASF) under one or more
- contributor license agreements. See the NOTICE file distributed with
- this work for additional information regarding copyright ownership.
- The ASF licenses this file to You under the Apache License, Version 2.0
- (the "License"); you may not use this file except in compliance with
- the License. You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
-
- This file is an adaptation of Hive's org/apache/hadoop/hive/ql/IdentifiersParser.g grammar.
-*/
-
-parser grammar KeywordParser;
-
-options
-{
-output=AST;
-ASTLabelType=CommonTree;
-backtrack=false;
-k=3;
-}
-
-@members {
- @Override
- public Object recoverFromMismatchedSet(IntStream input,
- RecognitionException re, BitSet follow) throws RecognitionException {
- throw re;
- }
- @Override
- public void displayRecognitionError(String[] tokenNames,
- RecognitionException e) {
- gParent.displayRecognitionError(tokenNames, e);
- }
- protected boolean useSQL11ReservedKeywordsForIdentifier() {
- return gParent.useSQL11ReservedKeywordsForIdentifier();
- }
-}
-
-@rulecatch {
-catch (RecognitionException e) {
- throw e;
-}
-}
-
-booleanValue
- :
- KW_TRUE^ | KW_FALSE^
- ;
-
-booleanValueTok
- :
- KW_TRUE -> TOK_TRUE
- | KW_FALSE -> TOK_FALSE
- ;
-
-tableOrPartition
- :
- tableName partitionSpec? -> ^(TOK_TAB tableName partitionSpec?)
- ;
-
-partitionSpec
- :
- KW_PARTITION
- LPAREN partitionVal (COMMA partitionVal )* RPAREN -> ^(TOK_PARTSPEC partitionVal +)
- ;
-
-partitionVal
- :
- identifier (EQUAL constant)? -> ^(TOK_PARTVAL identifier constant?)
- ;
-
-dropPartitionSpec
- :
- KW_PARTITION
- LPAREN dropPartitionVal (COMMA dropPartitionVal )* RPAREN -> ^(TOK_PARTSPEC dropPartitionVal +)
- ;
-
-dropPartitionVal
- :
- identifier dropPartitionOperator constant -> ^(TOK_PARTVAL identifier dropPartitionOperator constant)
- ;
-
-dropPartitionOperator
- :
- EQUAL | NOTEQUAL | LESSTHANOREQUALTO | LESSTHAN | GREATERTHANOREQUALTO | GREATERTHAN
- ;
-
-sysFuncNames
- :
- KW_AND
- | KW_OR
- | KW_NOT
- | KW_LIKE
- | KW_IF
- | KW_CASE
- | KW_WHEN
- | KW_TINYINT
- | KW_SMALLINT
- | KW_INT
- | KW_BIGINT
- | KW_FLOAT
- | KW_DOUBLE
- | KW_BOOLEAN
- | KW_STRING
- | KW_BINARY
- | KW_ARRAY
- | KW_MAP
- | KW_STRUCT
- | KW_UNIONTYPE
- | EQUAL
- | EQUAL_NS
- | NOTEQUAL
- | LESSTHANOREQUALTO
- | LESSTHAN
- | GREATERTHANOREQUALTO
- | GREATERTHAN
- | DIVIDE
- | PLUS
- | MINUS
- | STAR
- | MOD
- | DIV
- | AMPERSAND
- | TILDE
- | BITWISEOR
- | BITWISEXOR
- | KW_RLIKE
- | KW_REGEXP
- | KW_IN
- | KW_BETWEEN
- ;
-
-descFuncNames
- :
- (sysFuncNames) => sysFuncNames
- | StringLiteral
- | functionIdentifier
- ;
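descFuncNames feeds DESCRIBE FUNCTION, which accepts built-in operator names, string literals, and ordinary function identifiers. A spark-shell one-liner (`spark` predefined; illustrative only):

// Prints the usage text registered for the built-in upper() function.
spark.sql("DESCRIBE FUNCTION upper").show(false)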
-
-//We are allowed to use FROM and TO in the CreateTableUsing command's options (in fact, any string seems to be accepted as an option key). However, we cannot simply add them to nonReserved, because doing so would break other existing rules. Hence the looseIdentifier and looseNonReserved rules below.
-looseIdentifier
- :
- Identifier
- | looseNonReserved -> Identifier[$looseNonReserved.text]
-    // If the parser is configured to enforce SQL11 reserved keywords (i.e., useSQL11ReservedKeywordsForIdentifier() = false),
-    // the SQL11 keywords used as identifiers in existing q tests will NOT be accepted here.
- | {useSQL11ReservedKeywordsForIdentifier()}? sql11ReservedKeywordsUsedAsIdentifier -> Identifier[$sql11ReservedKeywordsUsedAsIdentifier.text]
- ;
-
-identifier
- :
- Identifier
- | nonReserved -> Identifier[$nonReserved.text]
-    // If the parser is configured to enforce SQL11 reserved keywords (i.e., useSQL11ReservedKeywordsForIdentifier() = false),
-    // the SQL11 keywords used as identifiers in existing q tests will NOT be accepted here.
- | {useSQL11ReservedKeywordsForIdentifier()}? sql11ReservedKeywordsUsedAsIdentifier -> Identifier[$sql11ReservedKeywordsUsedAsIdentifier.text]
- ;
-
-functionIdentifier
-@init { gParent.pushMsg("function identifier", state); }
-@after { gParent.popMsg(state); }
- :
- identifier (DOT identifier)? -> identifier+
- ;
-
-principalIdentifier
-@init { gParent.pushMsg("identifier for principal spec", state); }
-@after { gParent.popMsg(state); }
- : identifier
- | QuotedIdentifier
- ;
-
-looseNonReserved
- : nonReserved | KW_FROM | KW_TO
- ;
-
-//The new version of nonReserved + sql11ReservedKeywordsUsedAsIdentifier = the old version of nonReserved.
-//Non-reserved keywords are the keywords that can also be used as identifiers.
-//All KW_* tokens are automatically not only keywords, but also reserved keywords.
-//That means they can NOT be used as identifiers.
-//If you would like to use one of them as an identifier, put it in the nonReserved list below.
-//If you are not sure, please refer to the SQL2011 column in
-//http://www.postgresql.org/docs/9.5/static/sql-keywords-appendix.html
-nonReserved
- :
- KW_ADD | KW_ADMIN | KW_AFTER | KW_ANALYZE | KW_ARCHIVE | KW_ASC | KW_BEFORE | KW_BUCKET | KW_BUCKETS
- | KW_CASCADE | KW_CHANGE | KW_CLUSTER | KW_CLUSTERED | KW_CLUSTERSTATUS | KW_COLLECTION | KW_COLUMNS
- | KW_COMMENT | KW_COMPACT | KW_COMPACTIONS | KW_COMPUTE | KW_CONCATENATE | KW_CONTINUE | KW_DATA | KW_DAY
- | KW_DATABASES | KW_DATETIME | KW_DBPROPERTIES | KW_DEFERRED | KW_DEFINED | KW_DELIMITED | KW_DEPENDENCY
- | KW_DESC | KW_DIRECTORIES | KW_DIRECTORY | KW_DISABLE | KW_DISTRIBUTE | KW_ELEM_TYPE
- | KW_ENABLE | KW_ESCAPED | KW_EXCLUSIVE | KW_EXPLAIN | KW_EXPORT | KW_FIELDS | KW_FILE | KW_FILEFORMAT
- | KW_FIRST | KW_FORMAT | KW_FORMATTED | KW_FUNCTIONS | KW_HOLD_DDLTIME | KW_HOUR | KW_IDXPROPERTIES | KW_IGNORE
- | KW_INDEX | KW_INDEXES | KW_INPATH | KW_INPUTDRIVER | KW_INPUTFORMAT | KW_ITEMS | KW_JAR
- | KW_KEYS | KW_KEY_TYPE | KW_LIMIT | KW_LINES | KW_LOAD | KW_LOCATION | KW_LOCK | KW_LOCKS | KW_LOGICAL | KW_LONG
- | KW_MAPJOIN | KW_MATERIALIZED | KW_METADATA | KW_MINUS | KW_MINUTE | KW_MONTH | KW_MSCK | KW_NOSCAN | KW_NO_DROP | KW_OFFLINE
- | KW_OPTION | KW_OUTPUTDRIVER | KW_OUTPUTFORMAT | KW_OVERWRITE | KW_OWNER | KW_PARTITIONED | KW_PARTITIONS | KW_PLUS | KW_PRETTY
- | KW_PRINCIPALS | KW_PROTECTION | KW_PURGE | KW_READ | KW_READONLY | KW_REBUILD | KW_RECORDREADER | KW_RECORDWRITER
- | KW_RELOAD | KW_RENAME | KW_REPAIR | KW_REPLACE | KW_REPLICATION | KW_RESTRICT | KW_REWRITE
- | KW_ROLE | KW_ROLES | KW_SCHEMA | KW_SCHEMAS | KW_SECOND | KW_SEMI | KW_SERDE | KW_SERDEPROPERTIES | KW_SERVER | KW_SETS | KW_SHARED
- | KW_SHOW | KW_SHOW_DATABASE | KW_SKEWED | KW_SORT | KW_SORTED | KW_SSL | KW_STATISTICS | KW_STORED
- | KW_STREAMTABLE | KW_STRING | KW_STRUCT | KW_TABLES | KW_TBLPROPERTIES | KW_TEMPORARY | KW_TERMINATED
- | KW_TINYINT | KW_TOUCH | KW_TRANSACTIONS | KW_UNARCHIVE | KW_UNDO | KW_UNIONTYPE | KW_UNLOCK | KW_UNSET
- | KW_UNSIGNED | KW_URI | KW_USE | KW_UTC | KW_UTCTIMESTAMP | KW_VALUE_TYPE | KW_VIEW | KW_WHILE | KW_YEAR
- | KW_WORK
- | KW_TRANSACTION
- | KW_WRITE
- | KW_ISOLATION
- | KW_LEVEL
- | KW_SNAPSHOT
- | KW_AUTOCOMMIT
- | KW_ANTI
- | KW_WEEK | KW_MILLISECOND | KW_MICROSECOND
- | KW_CLEAR | KW_LAZY | KW_CACHE | KW_UNCACHE | KW_DFS
-;
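A spark-shell sketch of non-reserved keywords used as column names (`spark` predefined, data made up; this reflects Spark SQL behavior in general and is only an illustration of the rule above):

import spark.implicits._

// DATA and SORT are lexer keywords, but both appear in nonReserved,
// so they can be used as plain identifiers.
Seq((1, "x"), (2, "y")).toDF("data", "sort").createOrReplaceTempView("t")
spark.sql("SELECT data, sort FROM t").show()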
-
-//The following SQL2011 reserved keywords can be used as cast function names only, not as identifiers.
-sql11ReservedKeywordsUsedAsCastFunctionName
- :
- KW_BIGINT | KW_BINARY | KW_BOOLEAN | KW_CURRENT_DATE | KW_CURRENT_TIMESTAMP | KW_DATE | KW_DOUBLE | KW_FLOAT | KW_INT | KW_SMALLINT | KW_TIMESTAMP
- ;
-
-//The following SQL2011 reserved keywords are used as identifiers in many q tests; they may be added back for backward compatibility.
-//We plan to remove this entire list after several releases.
-//Thus, please do not change the list below unless you know what you are doing.
-sql11ReservedKeywordsUsedAsIdentifier
- :
- KW_ALL | KW_ALTER | KW_ARRAY | KW_AS | KW_AUTHORIZATION | KW_BETWEEN | KW_BIGINT | KW_BINARY | KW_BOOLEAN
- | KW_BOTH | KW_BY | KW_CREATE | KW_CUBE | KW_CURRENT_DATE | KW_CURRENT_TIMESTAMP | KW_CURSOR | KW_DATE | KW_DECIMAL | KW_DELETE | KW_DESCRIBE
- | KW_DOUBLE | KW_DROP | KW_EXISTS | KW_EXTERNAL | KW_FALSE | KW_FETCH | KW_FLOAT | KW_FOR | KW_FULL | KW_GRANT
- | KW_GROUP | KW_GROUPING | KW_IMPORT | KW_IN | KW_INNER | KW_INSERT | KW_INT | KW_INTERSECT | KW_INTO | KW_IS | KW_LATERAL
- | KW_LEFT | KW_LIKE | KW_LOCAL | KW_NONE | KW_NULL | KW_OF | KW_ORDER | KW_OUT | KW_OUTER | KW_PARTITION
- | KW_PERCENT | KW_PROCEDURE | KW_RANGE | KW_READS | KW_REVOKE | KW_RIGHT
- | KW_ROLLUP | KW_ROW | KW_ROWS | KW_SET | KW_SMALLINT | KW_TABLE | KW_TIMESTAMP | KW_TO | KW_TRIGGER | KW_TRUE
- | KW_TRUNCATE | KW_UNION | KW_UPDATE | KW_USER | KW_USING | KW_VALUES | KW_WITH
-//The following two keywords come from MySQL. Although they are not keywords in SQL2011, they are reserved keywords in MySQL.
- | KW_REGEXP | KW_RLIKE
- ;
diff --git a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SelectClauseParser.g b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SelectClauseParser.g
deleted file mode 100644
index f18b6ec496..0000000000
--- a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SelectClauseParser.g
+++ /dev/null
@@ -1,235 +0,0 @@
-/**
- Licensed to the Apache Software Foundation (ASF) under one or more
- contributor license agreements. See the NOTICE file distributed with
- this work for additional information regarding copyright ownership.
- The ASF licenses this file to You under the Apache License, Version 2.0
- (the "License"); you may not use this file except in compliance with
- the License. You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
-
- This file is an adaptation of Hive's org/apache/hadoop/hive/ql/SelectClauseParser.g grammar.
-*/
-parser grammar SelectClauseParser;
-
-options
-{
-output=AST;
-ASTLabelType=CommonTree;
-backtrack=false;
-k=3;
-}
-
-@members {
- @Override
- public Object recoverFromMismatchedSet(IntStream input,
- RecognitionException re, BitSet follow) throws RecognitionException {
- throw re;
- }
- @Override
- public void displayRecognitionError(String[] tokenNames,
- RecognitionException e) {
- gParent.displayRecognitionError(tokenNames, e);
- }
- protected boolean useSQL11ReservedKeywordsForIdentifier() {
- return gParent.useSQL11ReservedKeywordsForIdentifier();
- }
-}
-
-@rulecatch {
-catch (RecognitionException e) {
- throw e;
-}
-}
-
-//----------------------- Rules for parsing selectClause -----------------------------
-// select a,b,c ...
-selectClause
-@init { gParent.pushMsg("select clause", state); }
-@after { gParent.popMsg(state); }
- :
- KW_SELECT hintClause? (((KW_ALL | dist=KW_DISTINCT)? selectList)
- | (transform=KW_TRANSFORM selectTrfmClause))
- -> {$transform == null && $dist == null}? ^(TOK_SELECT hintClause? selectList)
- -> {$transform == null && $dist != null}? ^(TOK_SELECTDI hintClause? selectList)
- -> ^(TOK_SELECT hintClause? ^(TOK_SELEXPR selectTrfmClause) )
- |
- trfmClause ->^(TOK_SELECT ^(TOK_SELEXPR trfmClause))
- ;
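The DISTINCT and ALL branches above correspond to the usual SQL forms; a quick spark-shell check (`spark` predefined, data made up):

import spark.implicits._

Seq((1, "eng"), (2, "eng"), (3, "ops")).toDF("id", "dept").createOrReplaceTempView("people")

// TOK_SELECTDI vs. the default TOK_SELECT branch.
spark.sql("SELECT DISTINCT dept FROM people").show()
spark.sql("SELECT ALL dept FROM people").show()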
-
-selectList
-@init { gParent.pushMsg("select list", state); }
-@after { gParent.popMsg(state); }
- :
- selectItem ( COMMA selectItem )* -> selectItem+
- ;
-
-selectTrfmClause
-@init { gParent.pushMsg("transform clause", state); }
-@after { gParent.popMsg(state); }
- :
- LPAREN selectExpressionList RPAREN
- inSerde=rowFormat inRec=recordWriter
- KW_USING StringLiteral
- ( KW_AS ((LPAREN (aliasList | columnNameTypeList) RPAREN) | (aliasList | columnNameTypeList)))?
- outSerde=rowFormat outRec=recordReader
- -> ^(TOK_TRANSFORM selectExpressionList $inSerde $inRec StringLiteral $outSerde $outRec aliasList? columnNameTypeList?)
- ;
-
-hintClause
-@init { gParent.pushMsg("hint clause", state); }
-@after { gParent.popMsg(state); }
- :
- DIVIDE STAR PLUS hintList STAR DIVIDE -> ^(TOK_HINTLIST hintList)
- ;
-
-hintList
-@init { gParent.pushMsg("hint list", state); }
-@after { gParent.popMsg(state); }
- :
- hintItem (COMMA hintItem)* -> hintItem+
- ;
-
-hintItem
-@init { gParent.pushMsg("hint item", state); }
-@after { gParent.popMsg(state); }
- :
- hintName (LPAREN hintArgs RPAREN)? -> ^(TOK_HINT hintName hintArgs?)
- ;
-
-hintName
-@init { gParent.pushMsg("hint name", state); }
-@after { gParent.popMsg(state); }
- :
- KW_MAPJOIN -> TOK_MAPJOIN
- | KW_STREAMTABLE -> TOK_STREAMTABLE
- ;
-
-hintArgs
-@init { gParent.pushMsg("hint arguments", state); }
-@after { gParent.popMsg(state); }
- :
- hintArgName (COMMA hintArgName)* -> ^(TOK_HINTARGLIST hintArgName+)
- ;
-
-hintArgName
-@init { gParent.pushMsg("hint argument name", state); }
-@after { gParent.popMsg(state); }
- :
- identifier
- ;
-
-selectItem
-@init { gParent.pushMsg("selection target", state); }
-@after { gParent.popMsg(state); }
- :
- (tableAllColumns) => tableAllColumns -> ^(TOK_SELEXPR tableAllColumns)
- |
- namedExpression
- ;
-
-namedExpression
-@init { gParent.pushMsg("select named expression", state); }
-@after { gParent.popMsg(state); }
- :
- ( expression
- ((KW_AS? identifier) | (KW_AS LPAREN identifier (COMMA identifier)* RPAREN))?
- ) -> ^(TOK_SELEXPR expression identifier*)
- ;
-
-trfmClause
-@init { gParent.pushMsg("transform clause", state); }
-@after { gParent.popMsg(state); }
- :
- ( KW_MAP selectExpressionList
- | KW_REDUCE selectExpressionList )
- inSerde=rowFormat inRec=recordWriter
- KW_USING StringLiteral
- ( KW_AS ((LPAREN (aliasList | columnNameTypeList) RPAREN) | (aliasList | columnNameTypeList)))?
- outSerde=rowFormat outRec=recordReader
- -> ^(TOK_TRANSFORM selectExpressionList $inSerde $inRec StringLiteral $outSerde $outRec aliasList? columnNameTypeList?)
- ;
-
-selectExpression
-@init { gParent.pushMsg("select expression", state); }
-@after { gParent.popMsg(state); }
- :
- (tableAllColumns) => tableAllColumns
- |
- expression
- ;
-
-selectExpressionList
-@init { gParent.pushMsg("select expression list", state); }
-@after { gParent.popMsg(state); }
- :
- selectExpression (COMMA selectExpression)* -> ^(TOK_EXPLIST selectExpression+)
- ;
-
-//---------------------- Rules for windowing clauses -------------------------------
-window_clause
-@init { gParent.pushMsg("window_clause", state); }
-@after { gParent.popMsg(state); }
-:
- KW_WINDOW window_defn (COMMA window_defn)* -> ^(KW_WINDOW window_defn+)
-;
-
-window_defn
-@init { gParent.pushMsg("window_defn", state); }
-@after { gParent.popMsg(state); }
-:
- Identifier KW_AS window_specification -> ^(TOK_WINDOWDEF Identifier window_specification)
-;
-
-window_specification
-@init { gParent.pushMsg("window_specification", state); }
-@after { gParent.popMsg(state); }
-:
- (Identifier | ( LPAREN Identifier? partitioningSpec? window_frame? RPAREN)) -> ^(TOK_WINDOWSPEC Identifier? partitioningSpec? window_frame?)
-;
-
-window_frame :
- window_range_expression |
- window_value_expression
-;
-
-window_range_expression
-@init { gParent.pushMsg("window_range_expression", state); }
-@after { gParent.popMsg(state); }
-:
- KW_ROWS sb=window_frame_start_boundary -> ^(TOK_WINDOWRANGE $sb) |
- KW_ROWS KW_BETWEEN s=window_frame_boundary KW_AND end=window_frame_boundary -> ^(TOK_WINDOWRANGE $s $end)
-;
-
-window_value_expression
-@init { gParent.pushMsg("window_value_expression", state); }
-@after { gParent.popMsg(state); }
-:
- KW_RANGE sb=window_frame_start_boundary -> ^(TOK_WINDOWVALUES $sb) |
- KW_RANGE KW_BETWEEN s=window_frame_boundary KW_AND end=window_frame_boundary -> ^(TOK_WINDOWVALUES $s $end)
-;
-
-window_frame_start_boundary
-@init { gParent.pushMsg("windowframestartboundary", state); }
-@after { gParent.popMsg(state); }
-:
- KW_UNBOUNDED KW_PRECEDING -> ^(KW_PRECEDING KW_UNBOUNDED) |
- KW_CURRENT KW_ROW -> ^(KW_CURRENT) |
- Number KW_PRECEDING -> ^(KW_PRECEDING Number)
-;
-
-window_frame_boundary
-@init { gParent.pushMsg("windowframeboundary", state); }
-@after { gParent.popMsg(state); }
-:
- KW_UNBOUNDED (r=KW_PRECEDING|r=KW_FOLLOWING) -> ^($r KW_UNBOUNDED) |
- KW_CURRENT KW_ROW -> ^(KW_CURRENT) |
- Number (d=KW_PRECEDING | d=KW_FOLLOWING ) -> ^($d Number)
-;
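A spark-shell sketch of a window specification with an explicit ROWS frame, matching the shapes parsed above (`spark` predefined, data made up; ordinary Spark SQL, not specific to this parser):

import spark.implicits._

Seq((1, "eng", 30), (2, "eng", 40), (3, "ops", 50))
  .toDF("id", "dept", "age")
  .createOrReplaceTempView("people")

// PARTITION BY / ORDER BY plus a ROWS BETWEEN ... AND CURRENT ROW frame.
spark.sql("""
  SELECT id, dept, age,
         sum(age) OVER (PARTITION BY dept ORDER BY id
                        ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS running_total
  FROM people
""").show()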
-
diff --git a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlLexer.g b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlLexer.g
deleted file mode 100644
index fd1ad59207..0000000000
--- a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlLexer.g
+++ /dev/null
@@ -1,491 +0,0 @@
-/**
- Licensed to the Apache Software Foundation (ASF) under one or more
- contributor license agreements. See the NOTICE file distributed with
- this work for additional information regarding copyright ownership.
- The ASF licenses this file to You under the Apache License, Version 2.0
- (the "License"); you may not use this file except in compliance with
- the License. You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
-
- This file is an adaptation of Hive's org/apache/hadoop/hive/ql/HiveLexer.g grammar.
-*/
-lexer grammar SparkSqlLexer;
-
-@lexer::header {
-package org.apache.spark.sql.catalyst.parser;
-
-}
-
-@lexer::members {
- private ParserConf parserConf;
- private ParseErrorReporter reporter;
-
- public void configure(ParserConf parserConf, ParseErrorReporter reporter) {
- this.parserConf = parserConf;
- this.reporter = reporter;
- }
-
- protected boolean allowQuotedId() {
- if (parserConf == null) {
- return true;
- }
- return parserConf.supportQuotedId();
- }
-
- @Override
- public void displayRecognitionError(String[] tokenNames, RecognitionException e) {
- if (reporter != null) {
- reporter.report(this, e, tokenNames);
- }
- }
-}
-
-// Keywords
-
-KW_TRUE : 'TRUE';
-KW_FALSE : 'FALSE';
-KW_ALL : 'ALL';
-KW_NONE: 'NONE';
-KW_AND : 'AND';
-KW_OR : 'OR';
-KW_NOT : 'NOT' | '!';
-KW_LIKE : 'LIKE';
-
-KW_IF : 'IF';
-KW_EXISTS : 'EXISTS';
-
-KW_ASC : 'ASC';
-KW_DESC : 'DESC';
-KW_ORDER : 'ORDER';
-KW_GROUP : 'GROUP';
-KW_BY : 'BY';
-KW_HAVING : 'HAVING';
-KW_WHERE : 'WHERE';
-KW_FROM : 'FROM';
-KW_AS : 'AS';
-KW_SELECT : 'SELECT';
-KW_DISTINCT : 'DISTINCT';
-KW_INSERT : 'INSERT';
-KW_OVERWRITE : 'OVERWRITE';
-KW_OUTER : 'OUTER';
-KW_UNIQUEJOIN : 'UNIQUEJOIN';
-KW_PRESERVE : 'PRESERVE';
-KW_JOIN : 'JOIN';
-KW_LEFT : 'LEFT';
-KW_RIGHT : 'RIGHT';
-KW_FULL : 'FULL';
-KW_ANTI : 'ANTI';
-KW_ON : 'ON';
-KW_PARTITION : 'PARTITION';
-KW_PARTITIONS : 'PARTITIONS';
-KW_TABLE: 'TABLE';
-KW_TABLES: 'TABLES';
-KW_COLUMNS: 'COLUMNS';
-KW_INDEX: 'INDEX';
-KW_INDEXES: 'INDEXES';
-KW_REBUILD: 'REBUILD';
-KW_FUNCTIONS: 'FUNCTIONS';
-KW_SHOW: 'SHOW';
-KW_MSCK: 'MSCK';
-KW_REPAIR: 'REPAIR';
-KW_DIRECTORY: 'DIRECTORY';
-KW_LOCAL: 'LOCAL';
-KW_TRANSFORM : 'TRANSFORM';
-KW_USING: 'USING';
-KW_CLUSTER: 'CLUSTER';
-KW_DISTRIBUTE: 'DISTRIBUTE';
-KW_SORT: 'SORT';
-KW_UNION: 'UNION';
-KW_EXCEPT: 'EXCEPT';
-KW_LOAD: 'LOAD';
-KW_EXPORT: 'EXPORT';
-KW_IMPORT: 'IMPORT';
-KW_REPLICATION: 'REPLICATION';
-KW_METADATA: 'METADATA';
-KW_DATA: 'DATA';
-KW_INPATH: 'INPATH';
-KW_IS: 'IS';
-KW_NULL: 'NULL';
-KW_CREATE: 'CREATE';
-KW_EXTERNAL: 'EXTERNAL';
-KW_ALTER: 'ALTER';
-KW_CHANGE: 'CHANGE';
-KW_COLUMN: 'COLUMN';
-KW_FIRST: 'FIRST';
-KW_AFTER: 'AFTER';
-KW_DESCRIBE: 'DESCRIBE';
-KW_DROP: 'DROP';
-KW_RENAME: 'RENAME';
-KW_TO: 'TO';
-KW_COMMENT: 'COMMENT';
-KW_BOOLEAN: 'BOOLEAN';
-KW_TINYINT: 'TINYINT';
-KW_SMALLINT: 'SMALLINT';
-KW_INT: 'INT';
-KW_BIGINT: 'BIGINT';
-KW_FLOAT: 'FLOAT';
-KW_DOUBLE: 'DOUBLE';
-KW_DATE: 'DATE';
-KW_DATETIME: 'DATETIME';
-KW_TIMESTAMP: 'TIMESTAMP';
-KW_INTERVAL: 'INTERVAL';
-KW_DECIMAL: 'DECIMAL';
-KW_STRING: 'STRING';
-KW_CHAR: 'CHAR';
-KW_VARCHAR: 'VARCHAR';
-KW_ARRAY: 'ARRAY';
-KW_STRUCT: 'STRUCT';
-KW_MAP: 'MAP';
-KW_UNIONTYPE: 'UNIONTYPE';
-KW_REDUCE: 'REDUCE';
-KW_PARTITIONED: 'PARTITIONED';
-KW_CLUSTERED: 'CLUSTERED';
-KW_SORTED: 'SORTED';
-KW_INTO: 'INTO';
-KW_BUCKETS: 'BUCKETS';
-KW_ROW: 'ROW';
-KW_ROWS: 'ROWS';
-KW_FORMAT: 'FORMAT';
-KW_DELIMITED: 'DELIMITED';
-KW_FIELDS: 'FIELDS';
-KW_TERMINATED: 'TERMINATED';
-KW_ESCAPED: 'ESCAPED';
-KW_COLLECTION: 'COLLECTION';
-KW_ITEMS: 'ITEMS';
-KW_KEYS: 'KEYS';
-KW_KEY_TYPE: '$KEY$';
-KW_LINES: 'LINES';
-KW_STORED: 'STORED';
-KW_FILEFORMAT: 'FILEFORMAT';
-KW_INPUTFORMAT: 'INPUTFORMAT';
-KW_OUTPUTFORMAT: 'OUTPUTFORMAT';
-KW_INPUTDRIVER: 'INPUTDRIVER';
-KW_OUTPUTDRIVER: 'OUTPUTDRIVER';
-KW_ENABLE: 'ENABLE';
-KW_DISABLE: 'DISABLE';
-KW_LOCATION: 'LOCATION';
-KW_TABLESAMPLE: 'TABLESAMPLE';
-KW_BUCKET: 'BUCKET';
-KW_OUT: 'OUT';
-KW_OF: 'OF';
-KW_PERCENT: 'PERCENT';
-KW_CAST: 'CAST';
-KW_ADD: 'ADD';
-KW_REPLACE: 'REPLACE';
-KW_RLIKE: 'RLIKE';
-KW_REGEXP: 'REGEXP';
-KW_TEMPORARY: 'TEMPORARY';
-KW_FUNCTION: 'FUNCTION';
-KW_MACRO: 'MACRO';
-KW_FILE: 'FILE';
-KW_JAR: 'JAR';
-KW_EXPLAIN: 'EXPLAIN';
-KW_EXTENDED: 'EXTENDED';
-KW_FORMATTED: 'FORMATTED';
-KW_PRETTY: 'PRETTY';
-KW_DEPENDENCY: 'DEPENDENCY';
-KW_LOGICAL: 'LOGICAL';
-KW_SERDE: 'SERDE';
-KW_WITH: 'WITH';
-KW_DEFERRED: 'DEFERRED';
-KW_SERDEPROPERTIES: 'SERDEPROPERTIES';
-KW_DBPROPERTIES: 'DBPROPERTIES';
-KW_LIMIT: 'LIMIT';
-KW_SET: 'SET';
-KW_UNSET: 'UNSET';
-KW_TBLPROPERTIES: 'TBLPROPERTIES';
-KW_IDXPROPERTIES: 'IDXPROPERTIES';
-KW_VALUE_TYPE: '$VALUE$';
-KW_ELEM_TYPE: '$ELEM$';
-KW_DEFINED: 'DEFINED';
-KW_CASE: 'CASE';
-KW_WHEN: 'WHEN';
-KW_THEN: 'THEN';
-KW_ELSE: 'ELSE';
-KW_END: 'END';
-KW_MAPJOIN: 'MAPJOIN';
-KW_STREAMTABLE: 'STREAMTABLE';
-KW_CLUSTERSTATUS: 'CLUSTERSTATUS';
-KW_UTC: 'UTC';
-KW_UTCTIMESTAMP: 'UTC_TIMESTAMP';
-KW_LONG: 'LONG';
-KW_DELETE: 'DELETE';
-KW_PLUS: 'PLUS';
-KW_MINUS: 'MINUS';
-KW_FETCH: 'FETCH';
-KW_INTERSECT: 'INTERSECT';
-KW_VIEW: 'VIEW';
-KW_IN: 'IN';
-KW_DATABASE: 'DATABASE';
-KW_DATABASES: 'DATABASES';
-KW_MATERIALIZED: 'MATERIALIZED';
-KW_SCHEMA: 'SCHEMA';
-KW_SCHEMAS: 'SCHEMAS';
-KW_GRANT: 'GRANT';
-KW_REVOKE: 'REVOKE';
-KW_SSL: 'SSL';
-KW_UNDO: 'UNDO';
-KW_LOCK: 'LOCK';
-KW_LOCKS: 'LOCKS';
-KW_UNLOCK: 'UNLOCK';
-KW_SHARED: 'SHARED';
-KW_EXCLUSIVE: 'EXCLUSIVE';
-KW_PROCEDURE: 'PROCEDURE';
-KW_UNSIGNED: 'UNSIGNED';
-KW_WHILE: 'WHILE';
-KW_READ: 'READ';
-KW_READS: 'READS';
-KW_PURGE: 'PURGE';
-KW_RANGE: 'RANGE';
-KW_ANALYZE: 'ANALYZE';
-KW_BEFORE: 'BEFORE';
-KW_BETWEEN: 'BETWEEN';
-KW_BOTH: 'BOTH';
-KW_BINARY: 'BINARY';
-KW_CROSS: 'CROSS';
-KW_CONTINUE: 'CONTINUE';
-KW_CURSOR: 'CURSOR';
-KW_TRIGGER: 'TRIGGER';
-KW_RECORDREADER: 'RECORDREADER';
-KW_RECORDWRITER: 'RECORDWRITER';
-KW_SEMI: 'SEMI';
-KW_LATERAL: 'LATERAL';
-KW_TOUCH: 'TOUCH';
-KW_ARCHIVE: 'ARCHIVE';
-KW_UNARCHIVE: 'UNARCHIVE';
-KW_COMPUTE: 'COMPUTE';
-KW_STATISTICS: 'STATISTICS';
-KW_USE: 'USE';
-KW_OPTION: 'OPTION';
-KW_CONCATENATE: 'CONCATENATE';
-KW_SHOW_DATABASE: 'SHOW_DATABASE';
-KW_UPDATE: 'UPDATE';
-KW_RESTRICT: 'RESTRICT';
-KW_CASCADE: 'CASCADE';
-KW_SKEWED: 'SKEWED';
-KW_ROLLUP: 'ROLLUP';
-KW_CUBE: 'CUBE';
-KW_DIRECTORIES: 'DIRECTORIES';
-KW_FOR: 'FOR';
-KW_WINDOW: 'WINDOW';
-KW_UNBOUNDED: 'UNBOUNDED';
-KW_PRECEDING: 'PRECEDING';
-KW_FOLLOWING: 'FOLLOWING';
-KW_CURRENT: 'CURRENT';
-KW_CURRENT_DATE: 'CURRENT_DATE';
-KW_CURRENT_TIMESTAMP: 'CURRENT_TIMESTAMP';
-KW_LESS: 'LESS';
-KW_MORE: 'MORE';
-KW_OVER: 'OVER';
-KW_GROUPING: 'GROUPING';
-KW_SETS: 'SETS';
-KW_TRUNCATE: 'TRUNCATE';
-KW_NOSCAN: 'NOSCAN';
-KW_PARTIALSCAN: 'PARTIALSCAN';
-KW_USER: 'USER';
-KW_ROLE: 'ROLE';
-KW_ROLES: 'ROLES';
-KW_INNER: 'INNER';
-KW_EXCHANGE: 'EXCHANGE';
-KW_URI: 'URI';
-KW_SERVER : 'SERVER';
-KW_ADMIN: 'ADMIN';
-KW_OWNER: 'OWNER';
-KW_PRINCIPALS: 'PRINCIPALS';
-KW_COMPACT: 'COMPACT';
-KW_COMPACTIONS: 'COMPACTIONS';
-KW_TRANSACTIONS: 'TRANSACTIONS';
-KW_REWRITE : 'REWRITE';
-KW_AUTHORIZATION: 'AUTHORIZATION';
-KW_CONF: 'CONF';
-KW_VALUES: 'VALUES';
-KW_RELOAD: 'RELOAD';
-KW_YEAR: 'YEAR'|'YEARS';
-KW_MONTH: 'MONTH'|'MONTHS';
-KW_DAY: 'DAY'|'DAYS';
-KW_HOUR: 'HOUR'|'HOURS';
-KW_MINUTE: 'MINUTE'|'MINUTES';
-KW_SECOND: 'SECOND'|'SECONDS';
-KW_START: 'START';
-KW_TRANSACTION: 'TRANSACTION';
-KW_COMMIT: 'COMMIT';
-KW_ROLLBACK: 'ROLLBACK';
-KW_WORK: 'WORK';
-KW_ONLY: 'ONLY';
-KW_WRITE: 'WRITE';
-KW_ISOLATION: 'ISOLATION';
-KW_LEVEL: 'LEVEL';
-KW_SNAPSHOT: 'SNAPSHOT';
-KW_AUTOCOMMIT: 'AUTOCOMMIT';
-KW_REFRESH: 'REFRESH';
-KW_OPTIONS: 'OPTIONS';
-KW_WEEK: 'WEEK'|'WEEKS';
-KW_MILLISECOND: 'MILLISECOND'|'MILLISECONDS';
-KW_MICROSECOND: 'MICROSECOND'|'MICROSECONDS';
-KW_CLEAR: 'CLEAR';
-KW_LAZY: 'LAZY';
-KW_CACHE: 'CACHE';
-KW_UNCACHE: 'UNCACHE';
-KW_DFS: 'DFS';
-
-KW_NATURAL: 'NATURAL';
-
-// Operators
-// NOTE: if you add a new function/operator, add it to sysFuncNames so that describe function _FUNC_ will work.
-
-DOT : '.'; // generated as a part of Number rule
-COLON : ':' ;
-COMMA : ',' ;
-SEMICOLON : ';' ;
-
-LPAREN : '(' ;
-RPAREN : ')' ;
-LSQUARE : '[' ;
-RSQUARE : ']' ;
-LCURLY : '{';
-RCURLY : '}';
-
-EQUAL : '=' | '==';
-EQUAL_NS : '<=>';
-NOTEQUAL : '<>' | '!=';
-LESSTHANOREQUALTO : '<=';
-LESSTHAN : '<';
-GREATERTHANOREQUALTO : '>=';
-GREATERTHAN : '>';
-
-DIVIDE : '/';
-PLUS : '+';
-MINUS : '-';
-STAR : '*';
-MOD : '%';
-DIV : 'DIV';
-
-AMPERSAND : '&';
-TILDE : '~';
-BITWISEOR : '|';
-BITWISEXOR : '^';
-QUESTION : '?';
-DOLLAR : '$';
-
-// LITERALS
-fragment
-Letter
- : 'a'..'z' | 'A'..'Z'
- ;
-
-fragment
-HexDigit
- : 'a'..'f' | 'A'..'F'
- ;
-
-fragment
-Digit
- :
- '0'..'9'
- ;
-
-fragment
-Exponent
- :
- ('e' | 'E') ( PLUS|MINUS )? (Digit)+
- ;
-
-fragment
-RegexComponent
- : 'a'..'z' | 'A'..'Z' | '0'..'9' | '_'
- | PLUS | STAR | QUESTION | MINUS | DOT
- | LPAREN | RPAREN | LSQUARE | RSQUARE | LCURLY | RCURLY
- | BITWISEXOR | BITWISEOR | DOLLAR | '!'
- ;
-
-StringLiteral
- :
- ( '\'' ( ~('\''|'\\') | ('\\' .) )* '\''
- | '\"' ( ~('\"'|'\\') | ('\\' .) )* '\"'
- )+
- ;
-
-BigintLiteral
- :
- (Digit)+ 'L'
- ;
-
-SmallintLiteral
- :
- (Digit)+ 'S'
- ;
-
-TinyintLiteral
- :
- (Digit)+ 'Y'
- ;
-
-DoubleLiteral
- :
- Number 'D'
- ;
-
-ByteLengthLiteral
- :
- (Digit)+ ('b' | 'B' | 'k' | 'K' | 'm' | 'M' | 'g' | 'G')
- ;
-
-Number
- :
- ((Digit+ (DOT Digit*)?) | (DOT Digit+)) Exponent?
- ;
-
-/*
-An Identifier can be:
-- tableName
-- columnName
-- select expr alias
-- lateral view aliases
-- database name
-- view name
-- subquery alias
-- function name
-- ptf argument identifier
-- index name
-- property name for: db,tbl,partition...
-- fileFormat
-- role name
-- privilege name
-- principal name
-- macro name
-- hint name
-- window name
-*/
-Identifier
- :
- (Letter | Digit | '_')+
- | {allowQuotedId()}? QuotedIdentifier /* though at the language level we allow all Identifiers to be QuotedIdentifiers;
- at the API level only columns are allowed to be of this form */
- | '`' RegexComponent+ '`'
- ;
-
-fragment
-QuotedIdentifier
- :
- '`' ( '``' | ~('`') )* '`' { setText(getText().replaceAll("``", "`")); }
- ;
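Backtick quoting (the QuotedIdentifier fragment above, where a doubled backtick stands for a literal one) lets keywords be used as identifiers. A spark-shell sketch (`spark` predefined, data made up; illustrative only):

import spark.implicits._

// "order" is a keyword, but backticks make it a plain identifier.
Seq((1, 10), (2, 20)).toDF("order", "amount").createOrReplaceTempView("t")
spark.sql("SELECT `order`, amount FROM t").show()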
-
-WS : (' '|'\r'|'\t'|'\n') {$channel=HIDDEN;}
- ;
-
-COMMENT
- : '--' (~('\n'|'\r'))*
- { $channel=HIDDEN; }
- ;
-
-/* Prevent that the lexer swallows unknown characters. */
-ANY
- :.
- ;
diff --git a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlParser.g b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlParser.g
deleted file mode 100644
index f0c236859d..0000000000
--- a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlParser.g
+++ /dev/null
@@ -1,2596 +0,0 @@
-/**
- Licensed to the Apache Software Foundation (ASF) under one or more
- contributor license agreements. See the NOTICE file distributed with
- this work for additional information regarding copyright ownership.
- The ASF licenses this file to You under the Apache License, Version 2.0
- (the "License"); you may not use this file except in compliance with
- the License. You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
-
- This file is an adaptation of Hive's org/apache/hadoop/hive/ql/HiveParser.g grammar.
-*/
-parser grammar SparkSqlParser;
-
-options
-{
-tokenVocab=SparkSqlLexer;
-output=AST;
-ASTLabelType=CommonTree;
-backtrack=false;
-k=3;
-}
-import SelectClauseParser, FromClauseParser, IdentifiersParser, KeywordParser, ExpressionParser;
-
-tokens {
-TOK_INSERT;
-TOK_QUERY;
-TOK_SELECT;
-TOK_SELECTDI;
-TOK_SELEXPR;
-TOK_FROM;
-TOK_TAB;
-TOK_PARTSPEC;
-TOK_PARTVAL;
-TOK_DIR;
-TOK_TABREF;
-TOK_SUBQUERY;
-TOK_INSERT_INTO;
-TOK_DESTINATION;
-TOK_ALLCOLREF;
-TOK_TABLE_OR_COL;
-TOK_FUNCTION;
-TOK_FUNCTIONDI;
-TOK_FUNCTIONSTAR;
-TOK_WHERE;
-TOK_OP_EQ;
-TOK_OP_NE;
-TOK_OP_LE;
-TOK_OP_LT;
-TOK_OP_GE;
-TOK_OP_GT;
-TOK_OP_DIV;
-TOK_OP_ADD;
-TOK_OP_SUB;
-TOK_OP_MUL;
-TOK_OP_MOD;
-TOK_OP_BITAND;
-TOK_OP_BITNOT;
-TOK_OP_BITOR;
-TOK_OP_BITXOR;
-TOK_OP_AND;
-TOK_OP_OR;
-TOK_OP_NOT;
-TOK_OP_LIKE;
-TOK_TRUE;
-TOK_FALSE;
-TOK_TRANSFORM;
-TOK_SERDE;
-TOK_SERDENAME;
-TOK_SERDEPROPS;
-TOK_EXPLIST;
-TOK_ALIASLIST;
-TOK_GROUPBY;
-TOK_ROLLUP_GROUPBY;
-TOK_CUBE_GROUPBY;
-TOK_GROUPING_SETS;
-TOK_GROUPING_SETS_EXPRESSION;
-TOK_HAVING;
-TOK_ORDERBY;
-TOK_CLUSTERBY;
-TOK_DISTRIBUTEBY;
-TOK_SORTBY;
-TOK_UNIONALL;
-TOK_UNIONDISTINCT;
-TOK_EXCEPT;
-TOK_INTERSECT;
-TOK_JOIN;
-TOK_LEFTOUTERJOIN;
-TOK_RIGHTOUTERJOIN;
-TOK_FULLOUTERJOIN;
-TOK_UNIQUEJOIN;
-TOK_CROSSJOIN;
-TOK_NATURALJOIN;
-TOK_NATURALLEFTOUTERJOIN;
-TOK_NATURALRIGHTOUTERJOIN;
-TOK_NATURALFULLOUTERJOIN;
-TOK_LOAD;
-TOK_EXPORT;
-TOK_IMPORT;
-TOK_REPLICATION;
-TOK_METADATA;
-TOK_NULL;
-TOK_ISNULL;
-TOK_ISNOTNULL;
-TOK_TINYINT;
-TOK_SMALLINT;
-TOK_INT;
-TOK_BIGINT;
-TOK_BOOLEAN;
-TOK_FLOAT;
-TOK_DOUBLE;
-TOK_DATE;
-TOK_DATELITERAL;
-TOK_DATETIME;
-TOK_TIMESTAMP;
-TOK_TIMESTAMPLITERAL;
-TOK_INTERVAL;
-TOK_INTERVAL_YEAR_MONTH;
-TOK_INTERVAL_YEAR_MONTH_LITERAL;
-TOK_INTERVAL_DAY_TIME;
-TOK_INTERVAL_DAY_TIME_LITERAL;
-TOK_INTERVAL_YEAR_LITERAL;
-TOK_INTERVAL_MONTH_LITERAL;
-TOK_INTERVAL_WEEK_LITERAL;
-TOK_INTERVAL_DAY_LITERAL;
-TOK_INTERVAL_HOUR_LITERAL;
-TOK_INTERVAL_MINUTE_LITERAL;
-TOK_INTERVAL_SECOND_LITERAL;
-TOK_INTERVAL_MILLISECOND_LITERAL;
-TOK_INTERVAL_MICROSECOND_LITERAL;
-TOK_STRING;
-TOK_CHAR;
-TOK_VARCHAR;
-TOK_BINARY;
-TOK_DECIMAL;
-TOK_LIST;
-TOK_STRUCT;
-TOK_MAP;
-TOK_UNIONTYPE;
-TOK_COLTYPELIST;
-TOK_CREATEDATABASE;
-TOK_CREATETABLE;
-TOK_CREATETABLEUSING;
-TOK_TRUNCATETABLE;
-TOK_CREATEINDEX;
-TOK_CREATEINDEX_INDEXTBLNAME;
-TOK_DEFERRED_REBUILDINDEX;
-TOK_DROPINDEX;
-TOK_LIKETABLE;
-TOK_DESCTABLE;
-TOK_DESCFUNCTION;
-TOK_ALTERTABLE;
-TOK_ALTERTABLE_RENAME;
-TOK_ALTERTABLE_ADDCOLS;
-TOK_ALTERTABLE_RENAMECOL;
-TOK_ALTERTABLE_RENAMEPART;
-TOK_ALTERTABLE_REPLACECOLS;
-TOK_ALTERTABLE_ADDPARTS;
-TOK_ALTERTABLE_DROPPARTS;
-TOK_ALTERTABLE_PARTCOLTYPE;
-TOK_ALTERTABLE_MERGEFILES;
-TOK_ALTERTABLE_TOUCH;
-TOK_ALTERTABLE_ARCHIVE;
-TOK_ALTERTABLE_UNARCHIVE;
-TOK_ALTERTABLE_SERDEPROPERTIES;
-TOK_ALTERTABLE_SERIALIZER;
-TOK_ALTERTABLE_UPDATECOLSTATS;
-TOK_TABLE_PARTITION;
-TOK_ALTERTABLE_FILEFORMAT;
-TOK_ALTERTABLE_LOCATION;
-TOK_ALTERTABLE_PROPERTIES;
-TOK_ALTERTABLE_CHANGECOL_AFTER_POSITION;
-TOK_ALTERTABLE_DROPPROPERTIES;
-TOK_ALTERTABLE_SKEWED;
-TOK_ALTERTABLE_EXCHANGEPARTITION;
-TOK_ALTERTABLE_SKEWED_LOCATION;
-TOK_ALTERTABLE_BUCKETS;
-TOK_ALTERTABLE_CLUSTER_SORT;
-TOK_ALTERTABLE_COMPACT;
-TOK_ALTERINDEX_REBUILD;
-TOK_ALTERINDEX_PROPERTIES;
-TOK_MSCK;
-TOK_SHOWDATABASES;
-TOK_SHOWTABLES;
-TOK_SHOWCOLUMNS;
-TOK_SHOWFUNCTIONS;
-TOK_SHOWPARTITIONS;
-TOK_SHOW_CREATEDATABASE;
-TOK_SHOW_CREATETABLE;
-TOK_SHOW_TABLESTATUS;
-TOK_SHOW_TBLPROPERTIES;
-TOK_SHOWLOCKS;
-TOK_SHOWCONF;
-TOK_LOCKTABLE;
-TOK_UNLOCKTABLE;
-TOK_LOCKDB;
-TOK_UNLOCKDB;
-TOK_SWITCHDATABASE;
-TOK_DROPDATABASE;
-TOK_DROPTABLE;
-TOK_DATABASECOMMENT;
-TOK_TABCOLLIST;
-TOK_TABCOL;
-TOK_TABLECOMMENT;
-TOK_TABLEPARTCOLS;
-TOK_TABLEROWFORMAT;
-TOK_TABLEROWFORMATFIELD;
-TOK_TABLEROWFORMATCOLLITEMS;
-TOK_TABLEROWFORMATMAPKEYS;
-TOK_TABLEROWFORMATLINES;
-TOK_TABLEROWFORMATNULL;
-TOK_TABLEFILEFORMAT;
-TOK_FILEFORMAT_GENERIC;
-TOK_OFFLINE;
-TOK_ENABLE;
-TOK_DISABLE;
-TOK_READONLY;
-TOK_NO_DROP;
-TOK_STORAGEHANDLER;
-TOK_NOT_CLUSTERED;
-TOK_NOT_SORTED;
-TOK_TABCOLNAME;
-TOK_TABLELOCATION;
-TOK_PARTITIONLOCATION;
-TOK_TABLEBUCKETSAMPLE;
-TOK_TABLESPLITSAMPLE;
-TOK_PERCENT;
-TOK_LENGTH;
-TOK_ROWCOUNT;
-TOK_TMP_FILE;
-TOK_TABSORTCOLNAMEASC;
-TOK_TABSORTCOLNAMEDESC;
-TOK_STRINGLITERALSEQUENCE;
-TOK_CREATEFUNCTION;
-TOK_DROPFUNCTION;
-TOK_RELOADFUNCTION;
-TOK_CREATEMACRO;
-TOK_DROPMACRO;
-TOK_TEMPORARY;
-TOK_CREATEVIEW;
-TOK_DROPVIEW;
-TOK_ALTERVIEW;
-TOK_ALTERVIEW_PROPERTIES;
-TOK_ALTERVIEW_DROPPROPERTIES;
-TOK_ALTERVIEW_ADDPARTS;
-TOK_ALTERVIEW_DROPPARTS;
-TOK_ALTERVIEW_RENAME;
-TOK_VIEWPARTCOLS;
-TOK_EXPLAIN;
-TOK_EXPLAIN_SQ_REWRITE;
-TOK_TABLESERIALIZER;
-TOK_TABLEPROPERTIES;
-TOK_TABLEPROPLIST;
-TOK_INDEXPROPERTIES;
-TOK_INDEXPROPLIST;
-TOK_TABTYPE;
-TOK_LIMIT;
-TOK_TABLEPROPERTY;
-TOK_IFEXISTS;
-TOK_IFNOTEXISTS;
-TOK_ORREPLACE;
-TOK_HINTLIST;
-TOK_HINT;
-TOK_MAPJOIN;
-TOK_STREAMTABLE;
-TOK_HINTARGLIST;
-TOK_USERSCRIPTCOLNAMES;
-TOK_USERSCRIPTCOLSCHEMA;
-TOK_RECORDREADER;
-TOK_RECORDWRITER;
-TOK_LEFTSEMIJOIN;
-TOK_ANTIJOIN;
-TOK_LATERAL_VIEW;
-TOK_LATERAL_VIEW_OUTER;
-TOK_TABALIAS;
-TOK_ANALYZE;
-TOK_CREATEROLE;
-TOK_DROPROLE;
-TOK_GRANT;
-TOK_REVOKE;
-TOK_SHOW_GRANT;
-TOK_PRIVILEGE_LIST;
-TOK_PRIVILEGE;
-TOK_PRINCIPAL_NAME;
-TOK_USER;
-TOK_GROUP;
-TOK_ROLE;
-TOK_RESOURCE_ALL;
-TOK_GRANT_WITH_OPTION;
-TOK_GRANT_WITH_ADMIN_OPTION;
-TOK_ADMIN_OPTION_FOR;
-TOK_GRANT_OPTION_FOR;
-TOK_PRIV_ALL;
-TOK_PRIV_ALTER_METADATA;
-TOK_PRIV_ALTER_DATA;
-TOK_PRIV_DELETE;
-TOK_PRIV_DROP;
-TOK_PRIV_INDEX;
-TOK_PRIV_INSERT;
-TOK_PRIV_LOCK;
-TOK_PRIV_SELECT;
-TOK_PRIV_SHOW_DATABASE;
-TOK_PRIV_CREATE;
-TOK_PRIV_OBJECT;
-TOK_PRIV_OBJECT_COL;
-TOK_GRANT_ROLE;
-TOK_REVOKE_ROLE;
-TOK_SHOW_ROLE_GRANT;
-TOK_SHOW_ROLES;
-TOK_SHOW_SET_ROLE;
-TOK_SHOW_ROLE_PRINCIPALS;
-TOK_SHOWINDEXES;
-TOK_SHOWDBLOCKS;
-TOK_INDEXCOMMENT;
-TOK_DESCDATABASE;
-TOK_DATABASEPROPERTIES;
-TOK_DATABASELOCATION;
-TOK_DBPROPLIST;
-TOK_ALTERDATABASE_PROPERTIES;
-TOK_ALTERDATABASE_OWNER;
-TOK_TABNAME;
-TOK_TABSRC;
-TOK_RESTRICT;
-TOK_CASCADE;
-TOK_TABLESKEWED;
-TOK_TABCOLVALUE;
-TOK_TABCOLVALUE_PAIR;
-TOK_TABCOLVALUES;
-TOK_SKEWED_LOCATIONS;
-TOK_SKEWED_LOCATION_LIST;
-TOK_SKEWED_LOCATION_MAP;
-TOK_STOREDASDIRS;
-TOK_PARTITIONINGSPEC;
-TOK_PTBLFUNCTION;
-TOK_WINDOWDEF;
-TOK_WINDOWSPEC;
-TOK_WINDOWVALUES;
-TOK_WINDOWRANGE;
-TOK_SUBQUERY_EXPR;
-TOK_SUBQUERY_OP;
-TOK_SUBQUERY_OP_NOTIN;
-TOK_SUBQUERY_OP_NOTEXISTS;
-TOK_DB_TYPE;
-TOK_TABLE_TYPE;
-TOK_CTE;
-TOK_ARCHIVE;
-TOK_FILE;
-TOK_JAR;
-TOK_RESOURCE_URI;
-TOK_RESOURCE_LIST;
-TOK_SHOW_COMPACTIONS;
-TOK_SHOW_TRANSACTIONS;
-TOK_DELETE_FROM;
-TOK_UPDATE_TABLE;
-TOK_SET_COLUMNS_CLAUSE;
-TOK_VALUE_ROW;
-TOK_VALUES_TABLE;
-TOK_VIRTUAL_TABLE;
-TOK_VIRTUAL_TABREF;
-TOK_ANONYMOUS;
-TOK_COL_NAME;
-TOK_URI_TYPE;
-TOK_SERVER_TYPE;
-TOK_START_TRANSACTION;
-TOK_ISOLATION_LEVEL;
-TOK_ISOLATION_SNAPSHOT;
-TOK_TXN_ACCESS_MODE;
-TOK_TXN_READ_ONLY;
-TOK_TXN_READ_WRITE;
-TOK_COMMIT;
-TOK_ROLLBACK;
-TOK_SET_AUTOCOMMIT;
-TOK_REFRESHTABLE;
-TOK_TABLEPROVIDER;
-TOK_TABLEOPTIONS;
-TOK_TABLEOPTION;
-TOK_CACHETABLE;
-TOK_UNCACHETABLE;
-TOK_CLEARCACHE;
-TOK_SETCONFIG;
-TOK_DFS;
-TOK_ADDFILE;
-TOK_ADDJAR;
-TOK_USING;
-}
-
-
-// Package headers
-@header {
-package org.apache.spark.sql.catalyst.parser;
-
-import java.util.Arrays;
-import java.util.Collection;
-import java.util.HashMap;
-}
-
-
-@members {
-  Stack<String> msgs = new Stack<String>();
-
- private static HashMap<String, String> xlateMap;
- static {
-    // This map is used to support auto-completion in the CLI.
- xlateMap = new HashMap<String, String>();
-
- // Keywords
- xlateMap.put("KW_TRUE", "TRUE");
- xlateMap.put("KW_FALSE", "FALSE");
- xlateMap.put("KW_ALL", "ALL");
- xlateMap.put("KW_NONE", "NONE");
- xlateMap.put("KW_AND", "AND");
- xlateMap.put("KW_OR", "OR");
- xlateMap.put("KW_NOT", "NOT");
- xlateMap.put("KW_LIKE", "LIKE");
-
- xlateMap.put("KW_ASC", "ASC");
- xlateMap.put("KW_DESC", "DESC");
- xlateMap.put("KW_ORDER", "ORDER");
- xlateMap.put("KW_BY", "BY");
- xlateMap.put("KW_GROUP", "GROUP");
- xlateMap.put("KW_WHERE", "WHERE");
- xlateMap.put("KW_FROM", "FROM");
- xlateMap.put("KW_AS", "AS");
- xlateMap.put("KW_SELECT", "SELECT");
- xlateMap.put("KW_DISTINCT", "DISTINCT");
- xlateMap.put("KW_INSERT", "INSERT");
- xlateMap.put("KW_OVERWRITE", "OVERWRITE");
- xlateMap.put("KW_OUTER", "OUTER");
- xlateMap.put("KW_JOIN", "JOIN");
- xlateMap.put("KW_LEFT", "LEFT");
- xlateMap.put("KW_RIGHT", "RIGHT");
- xlateMap.put("KW_FULL", "FULL");
- xlateMap.put("KW_ON", "ON");
- xlateMap.put("KW_PARTITION", "PARTITION");
- xlateMap.put("KW_PARTITIONS", "PARTITIONS");
- xlateMap.put("KW_TABLE", "TABLE");
- xlateMap.put("KW_TABLES", "TABLES");
- xlateMap.put("KW_TBLPROPERTIES", "TBLPROPERTIES");
- xlateMap.put("KW_SHOW", "SHOW");
- xlateMap.put("KW_MSCK", "MSCK");
- xlateMap.put("KW_DIRECTORY", "DIRECTORY");
- xlateMap.put("KW_LOCAL", "LOCAL");
- xlateMap.put("KW_TRANSFORM", "TRANSFORM");
- xlateMap.put("KW_USING", "USING");
- xlateMap.put("KW_CLUSTER", "CLUSTER");
- xlateMap.put("KW_DISTRIBUTE", "DISTRIBUTE");
- xlateMap.put("KW_SORT", "SORT");
- xlateMap.put("KW_UNION", "UNION");
- xlateMap.put("KW_LOAD", "LOAD");
- xlateMap.put("KW_DATA", "DATA");
- xlateMap.put("KW_INPATH", "INPATH");
- xlateMap.put("KW_IS", "IS");
- xlateMap.put("KW_NULL", "NULL");
- xlateMap.put("KW_CREATE", "CREATE");
- xlateMap.put("KW_EXTERNAL", "EXTERNAL");
- xlateMap.put("KW_ALTER", "ALTER");
- xlateMap.put("KW_DESCRIBE", "DESCRIBE");
- xlateMap.put("KW_DROP", "DROP");
- xlateMap.put("KW_RENAME", "RENAME");
- xlateMap.put("KW_TO", "TO");
- xlateMap.put("KW_COMMENT", "COMMENT");
- xlateMap.put("KW_BOOLEAN", "BOOLEAN");
- xlateMap.put("KW_TINYINT", "TINYINT");
- xlateMap.put("KW_SMALLINT", "SMALLINT");
- xlateMap.put("KW_INT", "INT");
- xlateMap.put("KW_BIGINT", "BIGINT");
- xlateMap.put("KW_FLOAT", "FLOAT");
- xlateMap.put("KW_DOUBLE", "DOUBLE");
- xlateMap.put("KW_DATE", "DATE");
- xlateMap.put("KW_DATETIME", "DATETIME");
- xlateMap.put("KW_TIMESTAMP", "TIMESTAMP");
- xlateMap.put("KW_STRING", "STRING");
- xlateMap.put("KW_BINARY", "BINARY");
- xlateMap.put("KW_ARRAY", "ARRAY");
- xlateMap.put("KW_MAP", "MAP");
- xlateMap.put("KW_REDUCE", "REDUCE");
- xlateMap.put("KW_PARTITIONED", "PARTITIONED");
- xlateMap.put("KW_CLUSTERED", "CLUSTERED");
- xlateMap.put("KW_SORTED", "SORTED");
- xlateMap.put("KW_INTO", "INTO");
- xlateMap.put("KW_BUCKETS", "BUCKETS");
- xlateMap.put("KW_ROW", "ROW");
- xlateMap.put("KW_FORMAT", "FORMAT");
- xlateMap.put("KW_DELIMITED", "DELIMITED");
- xlateMap.put("KW_FIELDS", "FIELDS");
- xlateMap.put("KW_TERMINATED", "TERMINATED");
- xlateMap.put("KW_COLLECTION", "COLLECTION");
- xlateMap.put("KW_ITEMS", "ITEMS");
- xlateMap.put("KW_KEYS", "KEYS");
- xlateMap.put("KW_KEY_TYPE", "\$KEY\$");
- xlateMap.put("KW_LINES", "LINES");
- xlateMap.put("KW_STORED", "STORED");
- xlateMap.put("KW_SEQUENCEFILE", "SEQUENCEFILE");
- xlateMap.put("KW_TEXTFILE", "TEXTFILE");
- xlateMap.put("KW_INPUTFORMAT", "INPUTFORMAT");
- xlateMap.put("KW_OUTPUTFORMAT", "OUTPUTFORMAT");
- xlateMap.put("KW_LOCATION", "LOCATION");
- xlateMap.put("KW_TABLESAMPLE", "TABLESAMPLE");
- xlateMap.put("KW_BUCKET", "BUCKET");
- xlateMap.put("KW_OUT", "OUT");
- xlateMap.put("KW_OF", "OF");
- xlateMap.put("KW_CAST", "CAST");
- xlateMap.put("KW_ADD", "ADD");
- xlateMap.put("KW_REPLACE", "REPLACE");
- xlateMap.put("KW_COLUMNS", "COLUMNS");
- xlateMap.put("KW_RLIKE", "RLIKE");
- xlateMap.put("KW_REGEXP", "REGEXP");
- xlateMap.put("KW_TEMPORARY", "TEMPORARY");
- xlateMap.put("KW_FUNCTION", "FUNCTION");
- xlateMap.put("KW_EXPLAIN", "EXPLAIN");
- xlateMap.put("KW_EXTENDED", "EXTENDED");
- xlateMap.put("KW_SERDE", "SERDE");
- xlateMap.put("KW_WITH", "WITH");
- xlateMap.put("KW_SERDEPROPERTIES", "SERDEPROPERTIES");
- xlateMap.put("KW_LIMIT", "LIMIT");
- xlateMap.put("KW_SET", "SET");
- xlateMap.put("KW_PROPERTIES", "TBLPROPERTIES");
- xlateMap.put("KW_VALUE_TYPE", "\$VALUE\$");
- xlateMap.put("KW_ELEM_TYPE", "\$ELEM\$");
- xlateMap.put("KW_DEFINED", "DEFINED");
- xlateMap.put("KW_SUBQUERY", "SUBQUERY");
- xlateMap.put("KW_REWRITE", "REWRITE");
- xlateMap.put("KW_UPDATE", "UPDATE");
- xlateMap.put("KW_VALUES", "VALUES");
- xlateMap.put("KW_PURGE", "PURGE");
- xlateMap.put("KW_WEEK", "WEEK");
- xlateMap.put("KW_MILLISECOND", "MILLISECOND");
- xlateMap.put("KW_MICROSECOND", "MICROSECOND");
- xlateMap.put("KW_CLEAR", "CLEAR");
- xlateMap.put("KW_LAZY", "LAZY");
- xlateMap.put("KW_CACHE", "CACHE");
- xlateMap.put("KW_UNCACHE", "UNCACHE");
- xlateMap.put("KW_DFS", "DFS");
-
- // Operators
- xlateMap.put("DOT", ".");
- xlateMap.put("COLON", ":");
- xlateMap.put("COMMA", ",");
- xlateMap.put("SEMICOLON", ");");
-
- xlateMap.put("LPAREN", "(");
- xlateMap.put("RPAREN", ")");
- xlateMap.put("LSQUARE", "[");
- xlateMap.put("RSQUARE", "]");
-
- xlateMap.put("EQUAL", "=");
- xlateMap.put("NOTEQUAL", "<>");
- xlateMap.put("EQUAL_NS", "<=>");
- xlateMap.put("LESSTHANOREQUALTO", "<=");
- xlateMap.put("LESSTHAN", "<");
- xlateMap.put("GREATERTHANOREQUALTO", ">=");
- xlateMap.put("GREATERTHAN", ">");
-
- xlateMap.put("DIVIDE", "/");
- xlateMap.put("PLUS", "+");
- xlateMap.put("MINUS", "-");
- xlateMap.put("STAR", "*");
- xlateMap.put("MOD", "\%");
-
- xlateMap.put("AMPERSAND", "&");
- xlateMap.put("TILDE", "~");
- xlateMap.put("BITWISEOR", "|");
- xlateMap.put("BITWISEXOR", "^");
- xlateMap.put("CharSetLiteral", "\\'");
- }
-
- public static Collection<String> getKeywords() {
- return xlateMap.values();
- }
-
- private static String xlate(String name) {
-
- String ret = xlateMap.get(name);
- if (ret == null) {
- ret = name;
- }
-
- return ret;
- }
-
- @Override
- public Object recoverFromMismatchedSet(IntStream input,
- RecognitionException re, BitSet follow) throws RecognitionException {
- throw re;
- }
-
- @Override
- public void displayRecognitionError(String[] tokenNames, RecognitionException e) {
- if (reporter != null) {
- reporter.report(this, e, tokenNames);
- }
- }
-
- @Override
- public String getErrorHeader(RecognitionException e) {
- String header = null;
- if (e.charPositionInLine < 0 && input.LT(-1) != null) {
- Token t = input.LT(-1);
- header = "line " + t.getLine() + ":" + t.getCharPositionInLine();
- } else {
- header = super.getErrorHeader(e);
- }
-
- return header;
- }
-
- @Override
- public String getErrorMessage(RecognitionException e, String[] tokenNames) {
- String msg = null;
-
- // Translate the token names to something that the user can understand
- String[] xlateNames = new String[tokenNames.length];
- for (int i = 0; i < tokenNames.length; ++i) {
- xlateNames[i] = SparkSqlParser.xlate(tokenNames[i]);
- }
-
- if (e instanceof NoViableAltException) {
- @SuppressWarnings("unused")
- NoViableAltException nvae = (NoViableAltException) e;
-      // For development, the message can additionally include
-      // "decision=<<" + nvae.grammarDecisionDescription + ">>",
-      // "(decision=" + nvae.decisionNumber + ")", and
-      // "state " + nvae.stateNumber.
- msg = "cannot recognize input near"
- + (input.LT(1) != null ? " " + getTokenErrorDisplay(input.LT(1)) : "")
- + (input.LT(2) != null ? " " + getTokenErrorDisplay(input.LT(2)) : "")
- + (input.LT(3) != null ? " " + getTokenErrorDisplay(input.LT(3)) : "");
- } else if (e instanceof MismatchedTokenException) {
- MismatchedTokenException mte = (MismatchedTokenException) e;
-      msg = super.getErrorMessage(e, xlateNames) + (input.LT(-1) == null ? "" : " near '" + input.LT(-1).getText() + "'");
- } else if (e instanceof FailedPredicateException) {
- FailedPredicateException fpe = (FailedPredicateException) e;
- msg = "Failed to recognize predicate '" + fpe.token.getText() + "'. Failed rule: '" + fpe.ruleName + "'";
- } else {
- msg = super.getErrorMessage(e, xlateNames);
- }
-
- if (msgs.size() > 0) {
- msg = msg + " in " + msgs.peek();
- }
- return msg;
- }
-
- public void pushMsg(String msg, RecognizerSharedState state) {
-    // ANTLR-generated code does not wrap the @init code with this backtracking check,
-    // even if the matching @after has it. If we have parser rules that do some lookahead
-    // with syntactic predicates, the push() and pop() calls can become unbalanced,
-    // so make sure both push and pop check the backtracking state.
- if (state.backtracking == 0) {
- msgs.push(msg);
- }
- }
-
- public void popMsg(RecognizerSharedState state) {
- if (state.backtracking == 0) {
- Object o = msgs.pop();
- }
- }
-
- // counter to generate unique union aliases
- private int aliasCounter;
- private String generateUnionAlias() {
- return "u_" + (++aliasCounter);
- }
-  private char[] excludedCharForColumnName = {'.', ':'};
- private boolean containExcludedCharForCreateTableColumnName(String input) {
- if (input.length() > 0) {
- if (input.charAt(0) == '`' && input.charAt(input.length() - 1) == '`') {
- // When column name is backquoted, we don't care about excluded chars.
- return false;
- }
- }
-    for (char c : excludedCharForColumnName) {
-      if (input.indexOf(c) > -1) {
- return true;
- }
- }
- return false;
- }
- private CommonTree throwSetOpException() throws RecognitionException {
- throw new FailedPredicateException(input, "orderByClause clusterByClause distributeByClause sortByClause limitClause can only be applied to the whole union.", "");
- }
- private CommonTree throwColumnNameException() throws RecognitionException {
-    throw new FailedPredicateException(input, Arrays.toString(excludedCharForColumnName) + " cannot be used in a column name in a CREATE TABLE statement.", "");
- }
-
- private ParserConf parserConf;
- private ParseErrorReporter reporter;
-
- public void configure(ParserConf parserConf, ParseErrorReporter reporter) {
- this.parserConf = parserConf;
- this.reporter = reporter;
- }
-
- protected boolean useSQL11ReservedKeywordsForIdentifier() {
- if (parserConf == null) {
- return true;
- }
- return !parserConf.supportSQL11ReservedKeywords();
- }
-}
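
The members block above is purely about presentation: getErrorMessage() runs every raw ANTLR token name (KW_SELECT, LPAREN, ...) through xlate() before building a message, and getKeywords() exposes the keyword side of the same map for CLI auto-completion. A minimal sketch of how that public entry point could be exercised; the wrapping class and the printing are illustrative assumptions, not code from this grammar:

    import java.util.ArrayList;
    import java.util.Collections;
    import java.util.List;

    // Hypothetical illustration: list the user-facing keywords exposed for
    // auto-completion by the generated SparkSqlParser (the class name comes from
    // the members block above). getErrorMessage() applies the same translation
    // to ANTLR's token names, so messages read "expecting SELECT" rather than
    // "expecting KW_SELECT".
    public final class KeywordSketch {
      public static void main(String[] args) {
        List<String> keywords = new ArrayList<>(SparkSqlParser.getKeywords());
        Collections.sort(keywords);
        System.out.println(keywords); // e.g. [..., SELECT, SHOW, TABLE, ...]
      }
    }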
-
-@rulecatch {
-catch (RecognitionException e) {
- reportError(e);
- throw e;
-}
-}
-
-// starting rule
-statement
- : explainStatement EOF
- | execStatement EOF
- | KW_ADD KW_JAR -> ^(TOK_ADDJAR)
- | KW_ADD KW_FILE -> ^(TOK_ADDFILE)
- | KW_DFS -> ^(TOK_DFS)
- | (KW_SET)=> KW_SET -> ^(TOK_SETCONFIG)
- ;
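
The statement rule above is the parser's entry point: a SQL string is tokenized, parsed, and rewritten by the `->` tree rules into an AST built from the TOK_* node types declared earlier. A minimal ANTLR 3 driver sketch, assuming the generated companion classes are named SparkSqlLexer and SparkSqlParser; a real driver would presumably also call configure(...) with a ParserConf and ParseErrorReporter, as defined in the members block:

    import org.antlr.runtime.ANTLRStringStream;
    import org.antlr.runtime.CommonTokenStream;
    import org.antlr.runtime.RecognitionException;
    import org.antlr.runtime.tree.CommonTree;

    // Illustrative driver sketch (the generated lexer/parser class names are
    // assumptions): parse one statement and print the rewritten TOK_* tree.
    public final class ParseSketch {
      public static void main(String[] args) throws RecognitionException {
        String sql = "SHOW TABLES";
        SparkSqlLexer lexer = new SparkSqlLexer(new ANTLRStringStream(sql));
        SparkSqlParser parser = new SparkSqlParser(new CommonTokenStream(lexer));
        // A production driver would also call parser.configure(parserConf, reporter)
        // so that displayRecognitionError() can forward errors (see the members block).
        CommonTree tree = (CommonTree) parser.statement().getTree();
        System.out.println(tree.toStringTree()); // nested (TOK_... ...) form
      }
    }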
-
-// Rule for expression parsing
-singleNamedExpression
- :
- namedExpression EOF
- ;
-
-// Rule for table name parsing
-singleTableName
- :
- tableName EOF
- ;
-
-explainStatement
-@init { pushMsg("explain statement", state); }
-@after { popMsg(state); }
- : KW_EXPLAIN (
- explainOption* execStatement -> ^(TOK_EXPLAIN execStatement explainOption*)
- |
- KW_REWRITE queryStatementExpression[true] -> ^(TOK_EXPLAIN_SQ_REWRITE queryStatementExpression))
- ;
-
-explainOption
-@init { msgs.push("explain option"); }
-@after { msgs.pop(); }
- : KW_EXTENDED|KW_FORMATTED|KW_DEPENDENCY|KW_LOGICAL|KW_AUTHORIZATION
- ;
-
-execStatement
-@init { pushMsg("statement", state); }
-@after { popMsg(state); }
- : queryStatementExpression[true]
- | loadStatement
- | exportStatement
- | importStatement
- | ddlStatement
- | deleteStatement
- | updateStatement
- | sqlTransactionStatement
- | cacheStatement
- ;
-
-loadStatement
-@init { pushMsg("load statement", state); }
-@after { popMsg(state); }
- : KW_LOAD KW_DATA (islocal=KW_LOCAL)? KW_INPATH (path=StringLiteral) (isoverwrite=KW_OVERWRITE)? KW_INTO KW_TABLE (tab=tableOrPartition)
- -> ^(TOK_LOAD $path $tab $islocal? $isoverwrite?)
- ;
-
-replicationClause
-@init { pushMsg("replication clause", state); }
-@after { popMsg(state); }
- : KW_FOR (isMetadataOnly=KW_METADATA)? KW_REPLICATION LPAREN (replId=StringLiteral) RPAREN
- -> ^(TOK_REPLICATION $replId $isMetadataOnly?)
- ;
-
-exportStatement
-@init { pushMsg("export statement", state); }
-@after { popMsg(state); }
- : KW_EXPORT
- KW_TABLE (tab=tableOrPartition)
- KW_TO (path=StringLiteral)
- replicationClause?
- -> ^(TOK_EXPORT $tab $path replicationClause?)
- ;
-
-importStatement
-@init { pushMsg("import statement", state); }
-@after { popMsg(state); }
- : KW_IMPORT
- ((ext=KW_EXTERNAL)? KW_TABLE (tab=tableOrPartition))?
- KW_FROM (path=StringLiteral)
- tableLocation?
- -> ^(TOK_IMPORT $path $tab? $ext? tableLocation?)
- ;
-
-ddlStatement
-@init { pushMsg("ddl statement", state); }
-@after { popMsg(state); }
- : createDatabaseStatement
- | switchDatabaseStatement
- | dropDatabaseStatement
- | createTableStatement
- | dropTableStatement
- | truncateTableStatement
- | alterStatement
- | descStatement
- | refreshStatement
- | showStatement
- | metastoreCheck
- | createViewStatement
- | dropViewStatement
- | createFunctionStatement
- | createMacroStatement
- | createIndexStatement
- | dropIndexStatement
- | dropFunctionStatement
- | reloadFunctionStatement
- | dropMacroStatement
- | analyzeStatement
- | lockStatement
- | unlockStatement
- | lockDatabase
- | unlockDatabase
- | createRoleStatement
- | dropRoleStatement
- | (grantPrivileges) => grantPrivileges
- | (revokePrivileges) => revokePrivileges
- | showGrants
- | showRoleGrants
- | showRolePrincipals
- | showRoles
- | grantRole
- | revokeRole
- | setRole
- | showCurrentRole
- ;
-
-ifExists
-@init { pushMsg("if exists clause", state); }
-@after { popMsg(state); }
- : KW_IF KW_EXISTS
- -> ^(TOK_IFEXISTS)
- ;
-
-restrictOrCascade
-@init { pushMsg("restrict or cascade clause", state); }
-@after { popMsg(state); }
- : KW_RESTRICT
- -> ^(TOK_RESTRICT)
- | KW_CASCADE
- -> ^(TOK_CASCADE)
- ;
-
-ifNotExists
-@init { pushMsg("if not exists clause", state); }
-@after { popMsg(state); }
- : KW_IF KW_NOT KW_EXISTS
- -> ^(TOK_IFNOTEXISTS)
- ;
-
-storedAsDirs
-@init { pushMsg("stored as directories", state); }
-@after { popMsg(state); }
- : KW_STORED KW_AS KW_DIRECTORIES
- -> ^(TOK_STOREDASDIRS)
- ;
-
-orReplace
-@init { pushMsg("or replace clause", state); }
-@after { popMsg(state); }
- : KW_OR KW_REPLACE
- -> ^(TOK_ORREPLACE)
- ;
-
-createDatabaseStatement
-@init { pushMsg("create database statement", state); }
-@after { popMsg(state); }
- : KW_CREATE (KW_DATABASE|KW_SCHEMA)
- ifNotExists?
- name=identifier
- databaseComment?
- dbLocation?
- (KW_WITH KW_DBPROPERTIES dbprops=dbProperties)?
- -> ^(TOK_CREATEDATABASE $name ifNotExists? dbLocation? databaseComment? $dbprops?)
- ;
-
-dbLocation
-@init { pushMsg("database location specification", state); }
-@after { popMsg(state); }
- :
- KW_LOCATION locn=StringLiteral -> ^(TOK_DATABASELOCATION $locn)
- ;
-
-dbProperties
-@init { pushMsg("dbproperties", state); }
-@after { popMsg(state); }
- :
- LPAREN dbPropertiesList RPAREN -> ^(TOK_DATABASEPROPERTIES dbPropertiesList)
- ;
-
-dbPropertiesList
-@init { pushMsg("database properties list", state); }
-@after { popMsg(state); }
- :
- keyValueProperty (COMMA keyValueProperty)* -> ^(TOK_DBPROPLIST keyValueProperty+)
- ;
-
-
-switchDatabaseStatement
-@init { pushMsg("switch database statement", state); }
-@after { popMsg(state); }
- : KW_USE identifier
- -> ^(TOK_SWITCHDATABASE identifier)
- ;
-
-dropDatabaseStatement
-@init { pushMsg("drop database statement", state); }
-@after { popMsg(state); }
- : KW_DROP (KW_DATABASE|KW_SCHEMA) ifExists? identifier restrictOrCascade?
- -> ^(TOK_DROPDATABASE identifier ifExists? restrictOrCascade?)
- ;
-
-databaseComment
-@init { pushMsg("database's comment", state); }
-@after { popMsg(state); }
- : KW_COMMENT comment=StringLiteral
- -> ^(TOK_DATABASECOMMENT $comment)
- ;
-
-createTableStatement
-@init { pushMsg("create table statement", state); }
-@after { popMsg(state); }
- : KW_CREATE (temp=KW_TEMPORARY)? (ext=KW_EXTERNAL)? KW_TABLE ifNotExists? name=tableName
- (
- like=KW_LIKE likeName=tableName
- tableRowFormat?
- tableFileFormat?
- tableLocation?
- tablePropertiesPrefixed?
- -> ^(TOK_CREATETABLE $name $temp? $ext? ifNotExists?
- ^(TOK_LIKETABLE $likeName?)
- tableRowFormat?
- tableFileFormat?
- tableLocation?
- tablePropertiesPrefixed?
- )
- |
- (tableProvider) => tableProvider
- tableOpts?
- (KW_AS selectStatementWithCTE)?
- -> ^(TOK_CREATETABLEUSING $name $temp? ifNotExists?
- tableProvider
- tableOpts?
- selectStatementWithCTE?
- )
- | (LPAREN columnNameTypeList RPAREN)?
- (p=tableProvider?)
- tableOpts?
- tableComment?
- tablePartition?
- tableBuckets?
- tableSkewed?
- tableRowFormat?
- tableFileFormat?
- tableLocation?
- tablePropertiesPrefixed?
- (KW_AS selectStatementWithCTE)?
- -> {p != null}?
- ^(TOK_CREATETABLEUSING $name $temp? ifNotExists?
- columnNameTypeList?
- $p
- tableOpts?
- selectStatementWithCTE?
- )
- ->
- ^(TOK_CREATETABLE $name $temp? $ext? ifNotExists?
- ^(TOK_LIKETABLE $likeName?)
- columnNameTypeList?
- tableComment?
- tablePartition?
- tableBuckets?
- tableSkewed?
- tableRowFormat?
- tableFileFormat?
- tableLocation?
- tablePropertiesPrefixed?
- selectStatementWithCTE?
- )
- )
- ;
-
-truncateTableStatement
-@init { pushMsg("truncate table statement", state); }
-@after { popMsg(state); }
- : KW_TRUNCATE KW_TABLE tablePartitionPrefix (KW_COLUMNS LPAREN columnNameList RPAREN)? -> ^(TOK_TRUNCATETABLE tablePartitionPrefix columnNameList?);
-
-createIndexStatement
-@init { pushMsg("create index statement", state);}
-@after {popMsg(state);}
- : KW_CREATE KW_INDEX indexName=identifier
- KW_ON KW_TABLE tab=tableName LPAREN indexedCols=columnNameList RPAREN
- KW_AS typeName=StringLiteral
- autoRebuild?
- indexPropertiesPrefixed?
- indexTblName?
- tableRowFormat?
- tableFileFormat?
- tableLocation?
- tablePropertiesPrefixed?
- indexComment?
- ->^(TOK_CREATEINDEX $indexName $typeName $tab $indexedCols
- autoRebuild?
- indexPropertiesPrefixed?
- indexTblName?
- tableRowFormat?
- tableFileFormat?
- tableLocation?
- tablePropertiesPrefixed?
- indexComment?)
- ;
-
-indexComment
-@init { pushMsg("comment on an index", state);}
-@after {popMsg(state);}
- :
- KW_COMMENT comment=StringLiteral -> ^(TOK_INDEXCOMMENT $comment)
- ;
-
-autoRebuild
-@init { pushMsg("auto rebuild index", state);}
-@after {popMsg(state);}
- : KW_WITH KW_DEFERRED KW_REBUILD
- ->^(TOK_DEFERRED_REBUILDINDEX)
- ;
-
-indexTblName
-@init { pushMsg("index table name", state);}
-@after {popMsg(state);}
- : KW_IN KW_TABLE indexTbl=tableName
- ->^(TOK_CREATEINDEX_INDEXTBLNAME $indexTbl)
- ;
-
-indexPropertiesPrefixed
-@init { pushMsg("table properties with prefix", state); }
-@after { popMsg(state); }
- :
- KW_IDXPROPERTIES! indexProperties
- ;
-
-indexProperties
-@init { pushMsg("index properties", state); }
-@after { popMsg(state); }
- :
- LPAREN indexPropertiesList RPAREN -> ^(TOK_INDEXPROPERTIES indexPropertiesList)
- ;
-
-indexPropertiesList
-@init { pushMsg("index properties list", state); }
-@after { popMsg(state); }
- :
- keyValueProperty (COMMA keyValueProperty)* -> ^(TOK_INDEXPROPLIST keyValueProperty+)
- ;
-
-dropIndexStatement
-@init { pushMsg("drop index statement", state);}
-@after {popMsg(state);}
- : KW_DROP KW_INDEX ifExists? indexName=identifier KW_ON tab=tableName
- ->^(TOK_DROPINDEX $indexName $tab ifExists?)
- ;
-
-dropTableStatement
-@init { pushMsg("drop statement", state); }
-@after { popMsg(state); }
- : KW_DROP KW_TABLE ifExists? tableName KW_PURGE? replicationClause?
- -> ^(TOK_DROPTABLE tableName ifExists? KW_PURGE? replicationClause?)
- ;
-
-alterStatement
-@init { pushMsg("alter statement", state); }
-@after { popMsg(state); }
- : KW_ALTER KW_TABLE tableName alterTableStatementSuffix -> ^(TOK_ALTERTABLE tableName alterTableStatementSuffix)
- | KW_ALTER KW_VIEW tableName KW_AS? alterViewStatementSuffix -> ^(TOK_ALTERVIEW tableName alterViewStatementSuffix)
- | KW_ALTER KW_INDEX alterIndexStatementSuffix -> alterIndexStatementSuffix
- | KW_ALTER (KW_DATABASE|KW_SCHEMA) alterDatabaseStatementSuffix -> alterDatabaseStatementSuffix
- ;
-
-alterTableStatementSuffix
-@init { pushMsg("alter table statement", state); }
-@after { popMsg(state); }
- : (alterStatementSuffixRename[true]) => alterStatementSuffixRename[true]
- | alterStatementSuffixDropPartitions[true]
- | alterStatementSuffixAddPartitions[true]
- | alterStatementSuffixTouch
- | alterStatementSuffixArchive
- | alterStatementSuffixUnArchive
- | alterStatementSuffixProperties
- | alterStatementSuffixSkewedby
- | alterStatementSuffixExchangePartition
- | alterStatementPartitionKeyType
- | partitionSpec? alterTblPartitionStatementSuffix -> alterTblPartitionStatementSuffix partitionSpec?
- ;
-
-alterTblPartitionStatementSuffix
-@init {pushMsg("alter table partition statement suffix", state);}
-@after {popMsg(state);}
- : alterStatementSuffixFileFormat
- | alterStatementSuffixLocation
- | alterStatementSuffixMergeFiles
- | alterStatementSuffixSerdeProperties
- | alterStatementSuffixRenamePart
- | alterStatementSuffixBucketNum
- | alterTblPartitionStatementSuffixSkewedLocation
- | alterStatementSuffixClusterbySortby
- | alterStatementSuffixCompact
- | alterStatementSuffixUpdateStatsCol
- | alterStatementSuffixRenameCol
- | alterStatementSuffixAddCol
- ;
-
-alterStatementPartitionKeyType
-@init {msgs.push("alter partition key type"); }
-@after {msgs.pop();}
- : KW_PARTITION KW_COLUMN LPAREN columnNameType RPAREN
- -> ^(TOK_ALTERTABLE_PARTCOLTYPE columnNameType)
- ;
-
-alterViewStatementSuffix
-@init { pushMsg("alter view statement", state); }
-@after { popMsg(state); }
- : alterViewSuffixProperties
- | alterStatementSuffixRename[false]
- | alterStatementSuffixAddPartitions[false]
- | alterStatementSuffixDropPartitions[false]
- | selectStatementWithCTE
- ;
-
-alterIndexStatementSuffix
-@init { pushMsg("alter index statement", state); }
-@after { popMsg(state); }
- : indexName=identifier KW_ON tableName partitionSpec?
- (
- KW_REBUILD
- ->^(TOK_ALTERINDEX_REBUILD tableName $indexName partitionSpec?)
- |
- KW_SET KW_IDXPROPERTIES
- indexProperties
- ->^(TOK_ALTERINDEX_PROPERTIES tableName $indexName indexProperties)
- )
- ;
-
-alterDatabaseStatementSuffix
-@init { pushMsg("alter database statement", state); }
-@after { popMsg(state); }
- : alterDatabaseSuffixProperties
- | alterDatabaseSuffixSetOwner
- ;
-
-alterDatabaseSuffixProperties
-@init { pushMsg("alter database properties statement", state); }
-@after { popMsg(state); }
- : name=identifier KW_SET KW_DBPROPERTIES dbProperties
- -> ^(TOK_ALTERDATABASE_PROPERTIES $name dbProperties)
- ;
-
-alterDatabaseSuffixSetOwner
-@init { pushMsg("alter database set owner", state); }
-@after { popMsg(state); }
- : dbName=identifier KW_SET KW_OWNER principalName
- -> ^(TOK_ALTERDATABASE_OWNER $dbName principalName)
- ;
-
-alterStatementSuffixRename[boolean table]
-@init { pushMsg("rename statement", state); }
-@after { popMsg(state); }
- : KW_RENAME KW_TO tableName
- -> { table }? ^(TOK_ALTERTABLE_RENAME tableName)
- -> ^(TOK_ALTERVIEW_RENAME tableName)
- ;
-
-alterStatementSuffixAddCol
-@init { pushMsg("add column statement", state); }
-@after { popMsg(state); }
- : (add=KW_ADD | replace=KW_REPLACE) KW_COLUMNS LPAREN columnNameTypeList RPAREN restrictOrCascade?
- -> {$add != null}? ^(TOK_ALTERTABLE_ADDCOLS columnNameTypeList restrictOrCascade?)
- -> ^(TOK_ALTERTABLE_REPLACECOLS columnNameTypeList restrictOrCascade?)
- ;
-
-alterStatementSuffixRenameCol
-@init { pushMsg("rename column name", state); }
-@after { popMsg(state); }
- : KW_CHANGE KW_COLUMN? oldName=identifier newName=identifier colType (KW_COMMENT comment=StringLiteral)? alterStatementChangeColPosition? restrictOrCascade?
- ->^(TOK_ALTERTABLE_RENAMECOL $oldName $newName colType $comment? alterStatementChangeColPosition? restrictOrCascade?)
- ;
-
-alterStatementSuffixUpdateStatsCol
-@init { pushMsg("update column statistics", state); }
-@after { popMsg(state); }
- : KW_UPDATE KW_STATISTICS KW_FOR KW_COLUMN? colName=identifier KW_SET tableProperties (KW_COMMENT comment=StringLiteral)?
- ->^(TOK_ALTERTABLE_UPDATECOLSTATS $colName tableProperties $comment?)
- ;
-
-alterStatementChangeColPosition
- : first=KW_FIRST|KW_AFTER afterCol=identifier
- ->{$first != null}? ^(TOK_ALTERTABLE_CHANGECOL_AFTER_POSITION )
- -> ^(TOK_ALTERTABLE_CHANGECOL_AFTER_POSITION $afterCol)
- ;
-
-alterStatementSuffixAddPartitions[boolean table]
-@init { pushMsg("add partition statement", state); }
-@after { popMsg(state); }
- : KW_ADD ifNotExists? alterStatementSuffixAddPartitionsElement+
- -> { table }? ^(TOK_ALTERTABLE_ADDPARTS ifNotExists? alterStatementSuffixAddPartitionsElement+)
- -> ^(TOK_ALTERVIEW_ADDPARTS ifNotExists? alterStatementSuffixAddPartitionsElement+)
- ;
-
-alterStatementSuffixAddPartitionsElement
- : partitionSpec partitionLocation?
- ;
-
-alterStatementSuffixTouch
-@init { pushMsg("touch statement", state); }
-@after { popMsg(state); }
- : KW_TOUCH (partitionSpec)*
- -> ^(TOK_ALTERTABLE_TOUCH (partitionSpec)*)
- ;
-
-alterStatementSuffixArchive
-@init { pushMsg("archive statement", state); }
-@after { popMsg(state); }
- : KW_ARCHIVE (partitionSpec)*
- -> ^(TOK_ALTERTABLE_ARCHIVE (partitionSpec)*)
- ;
-
-alterStatementSuffixUnArchive
-@init { pushMsg("unarchive statement", state); }
-@after { popMsg(state); }
- : KW_UNARCHIVE (partitionSpec)*
- -> ^(TOK_ALTERTABLE_UNARCHIVE (partitionSpec)*)
- ;
-
-partitionLocation
-@init { pushMsg("partition location", state); }
-@after { popMsg(state); }
- :
- KW_LOCATION locn=StringLiteral -> ^(TOK_PARTITIONLOCATION $locn)
- ;
-
-alterStatementSuffixDropPartitions[boolean table]
-@init { pushMsg("drop partition statement", state); }
-@after { popMsg(state); }
- : KW_DROP ifExists? dropPartitionSpec (COMMA dropPartitionSpec)* KW_PURGE? replicationClause?
- -> { table }? ^(TOK_ALTERTABLE_DROPPARTS dropPartitionSpec+ ifExists? KW_PURGE? replicationClause?)
- -> ^(TOK_ALTERVIEW_DROPPARTS dropPartitionSpec+ ifExists? replicationClause?)
- ;
-
-alterStatementSuffixProperties
-@init { pushMsg("alter properties statement", state); }
-@after { popMsg(state); }
- : KW_SET KW_TBLPROPERTIES tableProperties
- -> ^(TOK_ALTERTABLE_PROPERTIES tableProperties)
- | KW_UNSET KW_TBLPROPERTIES ifExists? tableProperties
- -> ^(TOK_ALTERTABLE_DROPPROPERTIES tableProperties ifExists?)
- ;
-
-alterViewSuffixProperties
-@init { pushMsg("alter view properties statement", state); }
-@after { popMsg(state); }
- : KW_SET KW_TBLPROPERTIES tableProperties
- -> ^(TOK_ALTERVIEW_PROPERTIES tableProperties)
- | KW_UNSET KW_TBLPROPERTIES ifExists? tableProperties
- -> ^(TOK_ALTERVIEW_DROPPROPERTIES tableProperties ifExists?)
- ;
-
-alterStatementSuffixSerdeProperties
-@init { pushMsg("alter serdes statement", state); }
-@after { popMsg(state); }
- : KW_SET KW_SERDE serdeName=StringLiteral (KW_WITH KW_SERDEPROPERTIES tableProperties)?
- -> ^(TOK_ALTERTABLE_SERIALIZER $serdeName tableProperties?)
- | KW_SET KW_SERDEPROPERTIES tableProperties
- -> ^(TOK_ALTERTABLE_SERDEPROPERTIES tableProperties)
- ;
-
-tablePartitionPrefix
-@init {pushMsg("table partition prefix", state);}
-@after {popMsg(state);}
- : tableName partitionSpec?
- ->^(TOK_TABLE_PARTITION tableName partitionSpec?)
- ;
-
-alterStatementSuffixFileFormat
-@init {pushMsg("alter fileformat statement", state); }
-@after {popMsg(state);}
- : KW_SET KW_FILEFORMAT fileFormat
- -> ^(TOK_ALTERTABLE_FILEFORMAT fileFormat)
- ;
-
-alterStatementSuffixClusterbySortby
-@init {pushMsg("alter partition cluster by sort by statement", state);}
-@after {popMsg(state);}
- : KW_NOT KW_CLUSTERED -> ^(TOK_ALTERTABLE_CLUSTER_SORT TOK_NOT_CLUSTERED)
- | KW_NOT KW_SORTED -> ^(TOK_ALTERTABLE_CLUSTER_SORT TOK_NOT_SORTED)
- | tableBuckets -> ^(TOK_ALTERTABLE_CLUSTER_SORT tableBuckets)
- ;
-
-alterTblPartitionStatementSuffixSkewedLocation
-@init {pushMsg("alter partition skewed location", state);}
-@after {popMsg(state);}
- : KW_SET KW_SKEWED KW_LOCATION skewedLocations
- -> ^(TOK_ALTERTABLE_SKEWED_LOCATION skewedLocations)
- ;
-
-skewedLocations
-@init { pushMsg("skewed locations", state); }
-@after { popMsg(state); }
- :
- LPAREN skewedLocationsList RPAREN -> ^(TOK_SKEWED_LOCATIONS skewedLocationsList)
- ;
-
-skewedLocationsList
-@init { pushMsg("skewed locations list", state); }
-@after { popMsg(state); }
- :
- skewedLocationMap (COMMA skewedLocationMap)* -> ^(TOK_SKEWED_LOCATION_LIST skewedLocationMap+)
- ;
-
-skewedLocationMap
-@init { pushMsg("specifying skewed location map", state); }
-@after { popMsg(state); }
- :
- key=skewedValueLocationElement EQUAL value=StringLiteral -> ^(TOK_SKEWED_LOCATION_MAP $key $value)
- ;
-
-alterStatementSuffixLocation
-@init {pushMsg("alter location", state);}
-@after {popMsg(state);}
- : KW_SET KW_LOCATION newLoc=StringLiteral
- -> ^(TOK_ALTERTABLE_LOCATION $newLoc)
- ;
-
-
-alterStatementSuffixSkewedby
-@init {pushMsg("alter skewed by statement", state);}
-@after{popMsg(state);}
- : tableSkewed
- ->^(TOK_ALTERTABLE_SKEWED tableSkewed)
- |
- KW_NOT KW_SKEWED
- ->^(TOK_ALTERTABLE_SKEWED)
- |
- KW_NOT storedAsDirs
- ->^(TOK_ALTERTABLE_SKEWED storedAsDirs)
- ;
-
-alterStatementSuffixExchangePartition
-@init {pushMsg("alter exchange partition", state);}
-@after{popMsg(state);}
- : KW_EXCHANGE partitionSpec KW_WITH KW_TABLE exchangename=tableName
- -> ^(TOK_ALTERTABLE_EXCHANGEPARTITION partitionSpec $exchangename)
- ;
-
-alterStatementSuffixRenamePart
-@init { pushMsg("alter table rename partition statement", state); }
-@after { popMsg(state); }
- : KW_RENAME KW_TO partitionSpec
- ->^(TOK_ALTERTABLE_RENAMEPART partitionSpec)
- ;
-
-alterStatementSuffixStatsPart
-@init { pushMsg("alter table stats partition statement", state); }
-@after { popMsg(state); }
- : KW_UPDATE KW_STATISTICS KW_FOR KW_COLUMN? colName=identifier KW_SET tableProperties (KW_COMMENT comment=StringLiteral)?
- ->^(TOK_ALTERTABLE_UPDATECOLSTATS $colName tableProperties $comment?)
- ;
-
-alterStatementSuffixMergeFiles
-@init { pushMsg("", state); }
-@after { popMsg(state); }
- : KW_CONCATENATE
- -> ^(TOK_ALTERTABLE_MERGEFILES)
- ;
-
-alterStatementSuffixBucketNum
-@init { pushMsg("", state); }
-@after { popMsg(state); }
- : KW_INTO num=Number KW_BUCKETS
- -> ^(TOK_ALTERTABLE_BUCKETS $num)
- ;
-
-alterStatementSuffixCompact
-@init { msgs.push("compaction request"); }
-@after { msgs.pop(); }
- : KW_COMPACT compactType=StringLiteral
- -> ^(TOK_ALTERTABLE_COMPACT $compactType)
- ;
-
-
-fileFormat
-@init { pushMsg("file format specification", state); }
-@after { popMsg(state); }
- : KW_INPUTFORMAT inFmt=StringLiteral KW_OUTPUTFORMAT outFmt=StringLiteral KW_SERDE serdeCls=StringLiteral (KW_INPUTDRIVER inDriver=StringLiteral KW_OUTPUTDRIVER outDriver=StringLiteral)?
- -> ^(TOK_TABLEFILEFORMAT $inFmt $outFmt $serdeCls $inDriver? $outDriver?)
- | genericSpec=identifier -> ^(TOK_FILEFORMAT_GENERIC $genericSpec)
- ;
-
-tabTypeExpr
-@init { pushMsg("specifying table types", state); }
-@after { popMsg(state); }
- : identifier (DOT^ identifier)?
- (identifier (DOT^
- (
- (KW_ELEM_TYPE) => KW_ELEM_TYPE
- |
- (KW_KEY_TYPE) => KW_KEY_TYPE
- |
- (KW_VALUE_TYPE) => KW_VALUE_TYPE
- | identifier
- ))*
- )?
- ;
-
-partTypeExpr
-@init { pushMsg("specifying table partitions", state); }
-@after { popMsg(state); }
- : tabTypeExpr partitionSpec? -> ^(TOK_TABTYPE tabTypeExpr partitionSpec?)
- ;
-
-tabPartColTypeExpr
-@init { pushMsg("specifying table partitions columnName", state); }
-@after { popMsg(state); }
- : tableName partitionSpec? extColumnName? -> ^(TOK_TABTYPE tableName partitionSpec? extColumnName?)
- ;
-
-refreshStatement
-@init { pushMsg("refresh statement", state); }
-@after { popMsg(state); }
- :
- KW_REFRESH KW_TABLE tableName -> ^(TOK_REFRESHTABLE tableName)
- ;
-
-descStatement
-@init { pushMsg("describe statement", state); }
-@after { popMsg(state); }
- :
- (KW_DESCRIBE|KW_DESC)
- (
- (KW_DATABASE|KW_SCHEMA) => (KW_DATABASE|KW_SCHEMA) KW_EXTENDED? (dbName=identifier) -> ^(TOK_DESCDATABASE $dbName KW_EXTENDED?)
- |
- (KW_FUNCTION) => KW_FUNCTION KW_EXTENDED? (name=descFuncNames) -> ^(TOK_DESCFUNCTION $name KW_EXTENDED?)
- |
- (KW_FORMATTED|KW_EXTENDED|KW_PRETTY) => ((descOptions=KW_FORMATTED|descOptions=KW_EXTENDED|descOptions=KW_PRETTY) parttype=tabPartColTypeExpr) -> ^(TOK_DESCTABLE $parttype $descOptions)
- |
- parttype=tabPartColTypeExpr -> ^(TOK_DESCTABLE $parttype)
- )
- ;
-
-analyzeStatement
-@init { pushMsg("analyze statement", state); }
-@after { popMsg(state); }
- : KW_ANALYZE KW_TABLE (parttype=tableOrPartition) KW_COMPUTE KW_STATISTICS ((noscan=KW_NOSCAN) | (partialscan=KW_PARTIALSCAN)
- | (KW_FOR KW_COLUMNS (statsColumnName=columnNameList)?))?
- -> ^(TOK_ANALYZE $parttype $noscan? $partialscan? KW_COLUMNS? $statsColumnName?)
- ;
-
-showStatement
-@init { pushMsg("show statement", state); }
-@after { popMsg(state); }
- : KW_SHOW (KW_DATABASES|KW_SCHEMAS) (KW_LIKE showStmtIdentifier)? -> ^(TOK_SHOWDATABASES showStmtIdentifier?)
- | KW_SHOW KW_TABLES ((KW_FROM|KW_IN) db_name=identifier)? (KW_LIKE showStmtIdentifier|showStmtIdentifier)? -> ^(TOK_SHOWTABLES ^(TOK_FROM $db_name)? showStmtIdentifier?)
- | KW_SHOW KW_COLUMNS (KW_FROM|KW_IN) tableName ((KW_FROM|KW_IN) db_name=identifier)?
- -> ^(TOK_SHOWCOLUMNS tableName $db_name?)
- | KW_SHOW KW_FUNCTIONS (KW_LIKE showFunctionIdentifier|showFunctionIdentifier)? -> ^(TOK_SHOWFUNCTIONS KW_LIKE? showFunctionIdentifier?)
- | KW_SHOW KW_PARTITIONS tabName=tableName partitionSpec? -> ^(TOK_SHOWPARTITIONS $tabName partitionSpec?)
- | KW_SHOW KW_CREATE (
- (KW_DATABASE|KW_SCHEMA) => (KW_DATABASE|KW_SCHEMA) db_name=identifier -> ^(TOK_SHOW_CREATEDATABASE $db_name)
- |
- KW_TABLE tabName=tableName -> ^(TOK_SHOW_CREATETABLE $tabName)
- )
- | KW_SHOW KW_TABLE KW_EXTENDED ((KW_FROM|KW_IN) db_name=identifier)? KW_LIKE showStmtIdentifier partitionSpec?
- -> ^(TOK_SHOW_TABLESTATUS showStmtIdentifier $db_name? partitionSpec?)
- | KW_SHOW KW_TBLPROPERTIES tableName (LPAREN prptyName=StringLiteral RPAREN)? -> ^(TOK_SHOW_TBLPROPERTIES tableName $prptyName?)
- | KW_SHOW KW_LOCKS
- (
- (KW_DATABASE|KW_SCHEMA) => (KW_DATABASE|KW_SCHEMA) (dbName=Identifier) (isExtended=KW_EXTENDED)? -> ^(TOK_SHOWDBLOCKS $dbName $isExtended?)
- |
- (parttype=partTypeExpr)? (isExtended=KW_EXTENDED)? -> ^(TOK_SHOWLOCKS $parttype? $isExtended?)
- )
- | KW_SHOW (showOptions=KW_FORMATTED)? (KW_INDEX|KW_INDEXES) KW_ON showStmtIdentifier ((KW_FROM|KW_IN) db_name=identifier)?
- -> ^(TOK_SHOWINDEXES showStmtIdentifier $showOptions? $db_name?)
- | KW_SHOW KW_COMPACTIONS -> ^(TOK_SHOW_COMPACTIONS)
- | KW_SHOW KW_TRANSACTIONS -> ^(TOK_SHOW_TRANSACTIONS)
- | KW_SHOW KW_CONF StringLiteral -> ^(TOK_SHOWCONF StringLiteral)
- ;
-
-lockStatement
-@init { pushMsg("lock statement", state); }
-@after { popMsg(state); }
- : KW_LOCK KW_TABLE tableName partitionSpec? lockMode -> ^(TOK_LOCKTABLE tableName lockMode partitionSpec?)
- ;
-
-lockDatabase
-@init { pushMsg("lock database statement", state); }
-@after { popMsg(state); }
- : KW_LOCK (KW_DATABASE|KW_SCHEMA) (dbName=Identifier) lockMode -> ^(TOK_LOCKDB $dbName lockMode)
- ;
-
-lockMode
-@init { pushMsg("lock mode", state); }
-@after { popMsg(state); }
- : KW_SHARED | KW_EXCLUSIVE
- ;
-
-unlockStatement
-@init { pushMsg("unlock statement", state); }
-@after { popMsg(state); }
- : KW_UNLOCK KW_TABLE tableName partitionSpec? -> ^(TOK_UNLOCKTABLE tableName partitionSpec?)
- ;
-
-unlockDatabase
-@init { pushMsg("unlock database statement", state); }
-@after { popMsg(state); }
- : KW_UNLOCK (KW_DATABASE|KW_SCHEMA) (dbName=Identifier) -> ^(TOK_UNLOCKDB $dbName)
- ;
-
-createRoleStatement
-@init { pushMsg("create role", state); }
-@after { popMsg(state); }
- : KW_CREATE KW_ROLE roleName=identifier
- -> ^(TOK_CREATEROLE $roleName)
- ;
-
-dropRoleStatement
-@init {pushMsg("drop role", state);}
-@after {popMsg(state);}
- : KW_DROP KW_ROLE roleName=identifier
- -> ^(TOK_DROPROLE $roleName)
- ;
-
-grantPrivileges
-@init {pushMsg("grant privileges", state);}
-@after {popMsg(state);}
- : KW_GRANT privList=privilegeList
- privilegeObject?
- KW_TO principalSpecification
- withGrantOption?
- -> ^(TOK_GRANT $privList principalSpecification privilegeObject? withGrantOption?)
- ;
-
-revokePrivileges
-@init {pushMsg("revoke privileges", state);}
-@after {popMsg(state);}
- : KW_REVOKE grantOptionFor? privilegeList privilegeObject? KW_FROM principalSpecification
- -> ^(TOK_REVOKE privilegeList principalSpecification privilegeObject? grantOptionFor?)
- ;
-
-grantRole
-@init {pushMsg("grant role", state);}
-@after {popMsg(state);}
- : KW_GRANT KW_ROLE? identifier (COMMA identifier)* KW_TO principalSpecification withAdminOption?
- -> ^(TOK_GRANT_ROLE principalSpecification withAdminOption? identifier+)
- ;
-
-revokeRole
-@init {pushMsg("revoke role", state);}
-@after {popMsg(state);}
- : KW_REVOKE adminOptionFor? KW_ROLE? identifier (COMMA identifier)* KW_FROM principalSpecification
- -> ^(TOK_REVOKE_ROLE principalSpecification adminOptionFor? identifier+)
- ;
-
-showRoleGrants
-@init {pushMsg("show role grants", state);}
-@after {popMsg(state);}
- : KW_SHOW KW_ROLE KW_GRANT principalName
- -> ^(TOK_SHOW_ROLE_GRANT principalName)
- ;
-
-
-showRoles
-@init {pushMsg("show roles", state);}
-@after {popMsg(state);}
- : KW_SHOW KW_ROLES
- -> ^(TOK_SHOW_ROLES)
- ;
-
-showCurrentRole
-@init {pushMsg("show current role", state);}
-@after {popMsg(state);}
- : KW_SHOW KW_CURRENT KW_ROLES
- -> ^(TOK_SHOW_SET_ROLE)
- ;
-
-setRole
-@init {pushMsg("set role", state);}
-@after {popMsg(state);}
- : KW_SET KW_ROLE
- (
- (KW_ALL) => (all=KW_ALL) -> ^(TOK_SHOW_SET_ROLE Identifier[$all.text])
- |
- (KW_NONE) => (none=KW_NONE) -> ^(TOK_SHOW_SET_ROLE Identifier[$none.text])
- |
- identifier -> ^(TOK_SHOW_SET_ROLE identifier)
- )
- ;
-
-showGrants
-@init {pushMsg("show grants", state);}
-@after {popMsg(state);}
- : KW_SHOW KW_GRANT principalName? (KW_ON privilegeIncludeColObject)?
- -> ^(TOK_SHOW_GRANT principalName? privilegeIncludeColObject?)
- ;
-
-showRolePrincipals
-@init {pushMsg("show role principals", state);}
-@after {popMsg(state);}
- : KW_SHOW KW_PRINCIPALS roleName=identifier
- -> ^(TOK_SHOW_ROLE_PRINCIPALS $roleName)
- ;
-
-
-privilegeIncludeColObject
-@init {pushMsg("privilege object including columns", state);}
-@after {popMsg(state);}
- : (KW_ALL) => KW_ALL -> ^(TOK_RESOURCE_ALL)
- | privObjectCols -> ^(TOK_PRIV_OBJECT_COL privObjectCols)
- ;
-
-privilegeObject
-@init {pushMsg("privilege object", state);}
-@after {popMsg(state);}
- : KW_ON privObject -> ^(TOK_PRIV_OBJECT privObject)
- ;
-
-// Privilege object: a database, table, URI, or server. The TABLE keyword is optional; a bare object name is treated as a table.
-privObject
- : (KW_DATABASE|KW_SCHEMA) identifier -> ^(TOK_DB_TYPE identifier)
- | KW_TABLE? tableName partitionSpec? -> ^(TOK_TABLE_TYPE tableName partitionSpec?)
- | KW_URI (path=StringLiteral) -> ^(TOK_URI_TYPE $path)
- | KW_SERVER identifier -> ^(TOK_SERVER_TYPE identifier)
- ;
-
-privObjectCols
- : (KW_DATABASE|KW_SCHEMA) identifier -> ^(TOK_DB_TYPE identifier)
- | KW_TABLE? tableName (LPAREN cols=columnNameList RPAREN)? partitionSpec? -> ^(TOK_TABLE_TYPE tableName $cols? partitionSpec?)
- | KW_URI (path=StringLiteral) -> ^(TOK_URI_TYPE $path)
- | KW_SERVER identifier -> ^(TOK_SERVER_TYPE identifier)
- ;
-
-privilegeList
-@init {pushMsg("grant privilege list", state);}
-@after {popMsg(state);}
- : privlegeDef (COMMA privlegeDef)*
- -> ^(TOK_PRIVILEGE_LIST privlegeDef+)
- ;
-
-privlegeDef
-@init {pushMsg("grant privilege", state);}
-@after {popMsg(state);}
- : privilegeType (LPAREN cols=columnNameList RPAREN)?
- -> ^(TOK_PRIVILEGE privilegeType $cols?)
- ;
-
-privilegeType
-@init {pushMsg("privilege type", state);}
-@after {popMsg(state);}
- : KW_ALL -> ^(TOK_PRIV_ALL)
- | KW_ALTER -> ^(TOK_PRIV_ALTER_METADATA)
- | KW_UPDATE -> ^(TOK_PRIV_ALTER_DATA)
- | KW_CREATE -> ^(TOK_PRIV_CREATE)
- | KW_DROP -> ^(TOK_PRIV_DROP)
- | KW_INDEX -> ^(TOK_PRIV_INDEX)
- | KW_LOCK -> ^(TOK_PRIV_LOCK)
- | KW_SELECT -> ^(TOK_PRIV_SELECT)
- | KW_SHOW_DATABASE -> ^(TOK_PRIV_SHOW_DATABASE)
- | KW_INSERT -> ^(TOK_PRIV_INSERT)
- | KW_DELETE -> ^(TOK_PRIV_DELETE)
- ;
-
-principalSpecification
-@init { pushMsg("user/group/role name list", state); }
-@after { popMsg(state); }
- : principalName (COMMA principalName)* -> ^(TOK_PRINCIPAL_NAME principalName+)
- ;
-
-principalName
-@init {pushMsg("user|group|role name", state);}
-@after {popMsg(state);}
- : KW_USER principalIdentifier -> ^(TOK_USER principalIdentifier)
- | KW_GROUP principalIdentifier -> ^(TOK_GROUP principalIdentifier)
- | KW_ROLE identifier -> ^(TOK_ROLE identifier)
- ;
-
-withGrantOption
-@init {pushMsg("with grant option", state);}
-@after {popMsg(state);}
- : KW_WITH KW_GRANT KW_OPTION
- -> ^(TOK_GRANT_WITH_OPTION)
- ;
-
-grantOptionFor
-@init {pushMsg("grant option for", state);}
-@after {popMsg(state);}
- : KW_GRANT KW_OPTION KW_FOR
- -> ^(TOK_GRANT_OPTION_FOR)
-;
-
-adminOptionFor
-@init {pushMsg("admin option for", state);}
-@after {popMsg(state);}
- : KW_ADMIN KW_OPTION KW_FOR
- -> ^(TOK_ADMIN_OPTION_FOR)
-;
-
-withAdminOption
-@init {pushMsg("with admin option", state);}
-@after {popMsg(state);}
- : KW_WITH KW_ADMIN KW_OPTION
- -> ^(TOK_GRANT_WITH_ADMIN_OPTION)
- ;
-
-metastoreCheck
-@init { pushMsg("metastore check statement", state); }
-@after { popMsg(state); }
- : KW_MSCK (repair=KW_REPAIR)? (KW_TABLE tableName partitionSpec? (COMMA partitionSpec)*)?
- -> ^(TOK_MSCK $repair? (tableName partitionSpec*)?)
- ;
-
-resourceList
-@init { pushMsg("resource list", state); }
-@after { popMsg(state); }
- :
- resource (COMMA resource)* -> ^(TOK_RESOURCE_LIST resource+)
- ;
-
-resource
-@init { pushMsg("resource", state); }
-@after { popMsg(state); }
- :
- resType=resourceType resPath=StringLiteral -> ^(TOK_RESOURCE_URI $resType $resPath)
- ;
-
-resourceType
-@init { pushMsg("resource type", state); }
-@after { popMsg(state); }
- :
- KW_JAR -> ^(TOK_JAR)
- |
- KW_FILE -> ^(TOK_FILE)
- |
- KW_ARCHIVE -> ^(TOK_ARCHIVE)
- ;
-
-createFunctionStatement
-@init { pushMsg("create function statement", state); }
-@after { popMsg(state); }
- : KW_CREATE (temp=KW_TEMPORARY)? KW_FUNCTION functionIdentifier KW_AS StringLiteral
- (KW_USING rList=resourceList)?
- -> {$temp != null}? ^(TOK_CREATEFUNCTION functionIdentifier StringLiteral $rList? TOK_TEMPORARY)
- -> ^(TOK_CREATEFUNCTION functionIdentifier StringLiteral $rList?)
- ;
-
-dropFunctionStatement
-@init { pushMsg("drop function statement", state); }
-@after { popMsg(state); }
- : KW_DROP (temp=KW_TEMPORARY)? KW_FUNCTION ifExists? functionIdentifier
- -> {$temp != null}? ^(TOK_DROPFUNCTION functionIdentifier ifExists? TOK_TEMPORARY)
- -> ^(TOK_DROPFUNCTION functionIdentifier ifExists?)
- ;
-
-reloadFunctionStatement
-@init { pushMsg("reload function statement", state); }
-@after { popMsg(state); }
- : KW_RELOAD KW_FUNCTION -> ^(TOK_RELOADFUNCTION);
-
-createMacroStatement
-@init { pushMsg("create macro statement", state); }
-@after { popMsg(state); }
- : KW_CREATE KW_TEMPORARY KW_MACRO Identifier
- LPAREN columnNameTypeList? RPAREN expression
- -> ^(TOK_CREATEMACRO Identifier columnNameTypeList? expression)
- ;
-
-dropMacroStatement
-@init { pushMsg("drop macro statement", state); }
-@after { popMsg(state); }
- : KW_DROP KW_TEMPORARY KW_MACRO ifExists? Identifier
- -> ^(TOK_DROPMACRO Identifier ifExists?)
- ;
-
-createViewStatement
-@init {
- pushMsg("create view statement", state);
-}
-@after { popMsg(state); }
- : KW_CREATE (orReplace)? KW_VIEW (ifNotExists)? name=tableName
- (LPAREN columnNameCommentList RPAREN)? tableComment? viewPartition?
- tablePropertiesPrefixed?
- KW_AS
- selectStatementWithCTE
- -> ^(TOK_CREATEVIEW $name orReplace?
- ifNotExists?
- columnNameCommentList?
- tableComment?
- viewPartition?
- tablePropertiesPrefixed?
- selectStatementWithCTE
- )
- ;
-
-viewPartition
-@init { pushMsg("view partition specification", state); }
-@after { popMsg(state); }
- : KW_PARTITIONED KW_ON LPAREN columnNameList RPAREN
- -> ^(TOK_VIEWPARTCOLS columnNameList)
- ;
-
-dropViewStatement
-@init { pushMsg("drop view statement", state); }
-@after { popMsg(state); }
- : KW_DROP KW_VIEW ifExists? viewName -> ^(TOK_DROPVIEW viewName ifExists?)
- ;
-
-showFunctionIdentifier
-@init { pushMsg("identifier for show function statement", state); }
-@after { popMsg(state); }
- : functionIdentifier
- | StringLiteral
- ;
-
-showStmtIdentifier
-@init { pushMsg("identifier for show statement", state); }
-@after { popMsg(state); }
- : identifier
- | StringLiteral
- ;
-
-tableProvider
-@init { pushMsg("table's provider", state); }
-@after { popMsg(state); }
- :
- KW_USING Identifier (DOT Identifier)*
- -> ^(TOK_TABLEPROVIDER Identifier+)
- ;
-
-optionKeyValue
-@init { pushMsg("table's option specification", state); }
-@after { popMsg(state); }
- :
- (looseIdentifier (DOT looseIdentifier)*) StringLiteral
- -> ^(TOK_TABLEOPTION looseIdentifier+ StringLiteral)
- ;
-
-tableOpts
-@init { pushMsg("table's options", state); }
-@after { popMsg(state); }
- :
- KW_OPTIONS LPAREN optionKeyValue (COMMA optionKeyValue)* RPAREN
- -> ^(TOK_TABLEOPTIONS optionKeyValue+)
- ;
-
-tableComment
-@init { pushMsg("table's comment", state); }
-@after { popMsg(state); }
- :
- KW_COMMENT comment=StringLiteral -> ^(TOK_TABLECOMMENT $comment)
- ;
-
-tablePartition
-@init { pushMsg("table partition specification", state); }
-@after { popMsg(state); }
- : KW_PARTITIONED KW_BY LPAREN columnNameTypeList RPAREN
- -> ^(TOK_TABLEPARTCOLS columnNameTypeList)
- ;
-
-tableBuckets
-@init { pushMsg("table buckets specification", state); }
-@after { popMsg(state); }
- :
- KW_CLUSTERED KW_BY LPAREN bucketCols=columnNameList RPAREN (KW_SORTED KW_BY LPAREN sortCols=columnNameOrderList RPAREN)? KW_INTO num=Number KW_BUCKETS
- -> ^(TOK_ALTERTABLE_BUCKETS $bucketCols $sortCols? $num)
- ;
-
-tableSkewed
-@init { pushMsg("table skewed specification", state); }
-@after { popMsg(state); }
- :
- KW_SKEWED KW_BY LPAREN skewedCols=columnNameList RPAREN KW_ON LPAREN (skewedValues=skewedValueElement) RPAREN ((storedAsDirs) => storedAsDirs)?
- -> ^(TOK_TABLESKEWED $skewedCols $skewedValues storedAsDirs?)
- ;
-
-rowFormat
-@init { pushMsg("serde specification", state); }
-@after { popMsg(state); }
- : rowFormatSerde -> ^(TOK_SERDE rowFormatSerde)
- | rowFormatDelimited -> ^(TOK_SERDE rowFormatDelimited)
- | -> ^(TOK_SERDE)
- ;
-
-recordReader
-@init { pushMsg("record reader specification", state); }
-@after { popMsg(state); }
- : KW_RECORDREADER StringLiteral -> ^(TOK_RECORDREADER StringLiteral)
- | -> ^(TOK_RECORDREADER)
- ;
-
-recordWriter
-@init { pushMsg("record writer specification", state); }
-@after { popMsg(state); }
- : KW_RECORDWRITER StringLiteral -> ^(TOK_RECORDWRITER StringLiteral)
- | -> ^(TOK_RECORDWRITER)
- ;
-
-rowFormatSerde
-@init { pushMsg("serde format specification", state); }
-@after { popMsg(state); }
- : KW_ROW KW_FORMAT KW_SERDE name=StringLiteral (KW_WITH KW_SERDEPROPERTIES serdeprops=tableProperties)?
- -> ^(TOK_SERDENAME $name $serdeprops?)
- ;
-
-rowFormatDelimited
-@init { pushMsg("serde properties specification", state); }
-@after { popMsg(state); }
- :
- KW_ROW KW_FORMAT KW_DELIMITED tableRowFormatFieldIdentifier? tableRowFormatCollItemsIdentifier? tableRowFormatMapKeysIdentifier? tableRowFormatLinesIdentifier? tableRowNullFormat?
- -> ^(TOK_SERDEPROPS tableRowFormatFieldIdentifier? tableRowFormatCollItemsIdentifier? tableRowFormatMapKeysIdentifier? tableRowFormatLinesIdentifier? tableRowNullFormat?)
- ;
-
-tableRowFormat
-@init { pushMsg("table row format specification", state); }
-@after { popMsg(state); }
- :
- rowFormatDelimited
- -> ^(TOK_TABLEROWFORMAT rowFormatDelimited)
- | rowFormatSerde
- -> ^(TOK_TABLESERIALIZER rowFormatSerde)
- ;
-
-tablePropertiesPrefixed
-@init { pushMsg("table properties with prefix", state); }
-@after { popMsg(state); }
- :
- KW_TBLPROPERTIES! tableProperties
- ;
-
-tableProperties
-@init { pushMsg("table properties", state); }
-@after { popMsg(state); }
- :
- LPAREN tablePropertiesList RPAREN -> ^(TOK_TABLEPROPERTIES tablePropertiesList)
- ;
-
-tablePropertiesList
-@init { pushMsg("table properties list", state); }
-@after { popMsg(state); }
- :
- keyValueProperty (COMMA keyValueProperty)* -> ^(TOK_TABLEPROPLIST keyValueProperty+)
- |
- keyProperty (COMMA keyProperty)* -> ^(TOK_TABLEPROPLIST keyProperty+)
- ;
-
-keyValueProperty
-@init { pushMsg("specifying key/value property", state); }
-@after { popMsg(state); }
- :
- key=StringLiteral EQUAL value=StringLiteral -> ^(TOK_TABLEPROPERTY $key $value)
- ;
-
-keyProperty
-@init { pushMsg("specifying key property", state); }
-@after { popMsg(state); }
- :
- key=StringLiteral -> ^(TOK_TABLEPROPERTY $key TOK_NULL)
- ;
-
-tableRowFormatFieldIdentifier
-@init { pushMsg("table row format's field separator", state); }
-@after { popMsg(state); }
- :
- KW_FIELDS KW_TERMINATED KW_BY fldIdnt=StringLiteral (KW_ESCAPED KW_BY fldEscape=StringLiteral)?
- -> ^(TOK_TABLEROWFORMATFIELD $fldIdnt $fldEscape?)
- ;
-
-tableRowFormatCollItemsIdentifier
-@init { pushMsg("table row format's column separator", state); }
-@after { popMsg(state); }
- :
- KW_COLLECTION KW_ITEMS KW_TERMINATED KW_BY collIdnt=StringLiteral
- -> ^(TOK_TABLEROWFORMATCOLLITEMS $collIdnt)
- ;
-
-tableRowFormatMapKeysIdentifier
-@init { pushMsg("table row format's map key separator", state); }
-@after { popMsg(state); }
- :
- KW_MAP KW_KEYS KW_TERMINATED KW_BY mapKeysIdnt=StringLiteral
- -> ^(TOK_TABLEROWFORMATMAPKEYS $mapKeysIdnt)
- ;
-
-tableRowFormatLinesIdentifier
-@init { pushMsg("table row format's line separator", state); }
-@after { popMsg(state); }
- :
- KW_LINES KW_TERMINATED KW_BY linesIdnt=StringLiteral
- -> ^(TOK_TABLEROWFORMATLINES $linesIdnt)
- ;
-
-tableRowNullFormat
-@init { pushMsg("table row format's null specifier", state); }
-@after { popMsg(state); }
- :
- KW_NULL KW_DEFINED KW_AS nullIdnt=StringLiteral
- -> ^(TOK_TABLEROWFORMATNULL $nullIdnt)
- ;
-tableFileFormat
-@init { pushMsg("table file format specification", state); }
-@after { popMsg(state); }
- :
- (KW_STORED KW_AS KW_INPUTFORMAT) => KW_STORED KW_AS KW_INPUTFORMAT inFmt=StringLiteral KW_OUTPUTFORMAT outFmt=StringLiteral (KW_INPUTDRIVER inDriver=StringLiteral KW_OUTPUTDRIVER outDriver=StringLiteral)?
- -> ^(TOK_TABLEFILEFORMAT $inFmt $outFmt $inDriver? $outDriver?)
- | KW_STORED KW_BY storageHandler=StringLiteral
- (KW_WITH KW_SERDEPROPERTIES serdeprops=tableProperties)?
- -> ^(TOK_STORAGEHANDLER $storageHandler $serdeprops?)
- | KW_STORED KW_AS genericSpec=identifier
- -> ^(TOK_FILEFORMAT_GENERIC $genericSpec)
- ;
-
-tableLocation
-@init { pushMsg("table location specification", state); }
-@after { popMsg(state); }
- :
- KW_LOCATION locn=StringLiteral -> ^(TOK_TABLELOCATION $locn)
- ;
-
-columnNameTypeList
-@init { pushMsg("column name type list", state); }
-@after { popMsg(state); }
- : columnNameType (COMMA columnNameType)* -> ^(TOK_TABCOLLIST columnNameType+)
- ;
-
-columnNameColonTypeList
-@init { pushMsg("column name type list", state); }
-@after { popMsg(state); }
- : columnNameColonType (COMMA columnNameColonType)* -> ^(TOK_TABCOLLIST columnNameColonType+)
- ;
-
-columnNameList
-@init { pushMsg("column name list", state); }
-@after { popMsg(state); }
- : columnName (COMMA columnName)* -> ^(TOK_TABCOLNAME columnName+)
- ;
-
-columnName
-@init { pushMsg("column name", state); }
-@after { popMsg(state); }
- :
- identifier
- ;
-
-extColumnName
-@init { pushMsg("column name for complex types", state); }
-@after { popMsg(state); }
- :
- identifier (DOT^ ((KW_ELEM_TYPE) => KW_ELEM_TYPE | (KW_KEY_TYPE) => KW_KEY_TYPE | (KW_VALUE_TYPE) => KW_VALUE_TYPE | identifier))*
- ;
-
-columnNameOrderList
-@init { pushMsg("column name order list", state); }
-@after { popMsg(state); }
- : columnNameOrder (COMMA columnNameOrder)* -> ^(TOK_TABCOLNAME columnNameOrder+)
- ;
-
-skewedValueElement
-@init { pushMsg("skewed value element", state); }
-@after { popMsg(state); }
- :
- skewedColumnValues
- | skewedColumnValuePairList
- ;
-
-skewedColumnValuePairList
-@init { pushMsg("column value pair list", state); }
-@after { popMsg(state); }
- : skewedColumnValuePair (COMMA skewedColumnValuePair)* -> ^(TOK_TABCOLVALUE_PAIR skewedColumnValuePair+)
- ;
-
-skewedColumnValuePair
-@init { pushMsg("column value pair", state); }
-@after { popMsg(state); }
- :
- LPAREN colValues=skewedColumnValues RPAREN
- -> ^(TOK_TABCOLVALUES $colValues)
- ;
-
-skewedColumnValues
-@init { pushMsg("column values", state); }
-@after { popMsg(state); }
- : skewedColumnValue (COMMA skewedColumnValue)* -> ^(TOK_TABCOLVALUE skewedColumnValue+)
- ;
-
-skewedColumnValue
-@init { pushMsg("column value", state); }
-@after { popMsg(state); }
- :
- constant
- ;
-
-skewedValueLocationElement
-@init { pushMsg("skewed value location element", state); }
-@after { popMsg(state); }
- :
- skewedColumnValue
- | skewedColumnValuePair
- ;
-
-columnNameOrder
-@init { pushMsg("column name order", state); }
-@after { popMsg(state); }
- : identifier (asc=KW_ASC | desc=KW_DESC)?
- -> {$desc == null}? ^(TOK_TABSORTCOLNAMEASC identifier)
- -> ^(TOK_TABSORTCOLNAMEDESC identifier)
- ;
-
-columnNameCommentList
-@init { pushMsg("column name comment list", state); }
-@after { popMsg(state); }
- : columnNameComment (COMMA columnNameComment)* -> ^(TOK_TABCOLNAME columnNameComment+)
- ;
-
-columnNameComment
-@init { pushMsg("column name comment", state); }
-@after { popMsg(state); }
- : colName=identifier (KW_COMMENT comment=StringLiteral)?
- -> ^(TOK_TABCOL $colName TOK_NULL $comment?)
- ;
-
-columnRefOrder
-@init { pushMsg("column order", state); }
-@after { popMsg(state); }
- : expression (asc=KW_ASC | desc=KW_DESC)?
- -> {$desc == null}? ^(TOK_TABSORTCOLNAMEASC expression)
- -> ^(TOK_TABSORTCOLNAMEDESC expression)
- ;
-
-columnNameType
-@init { pushMsg("column specification", state); }
-@after { popMsg(state); }
- : colName=identifier colType (KW_COMMENT comment=StringLiteral)?
- -> {containExcludedCharForCreateTableColumnName($colName.text)}? {throwColumnNameException()}
- -> {$comment == null}? ^(TOK_TABCOL $colName colType)
- -> ^(TOK_TABCOL $colName colType $comment)
- ;
-
-columnNameColonType
-@init { pushMsg("column specification", state); }
-@after { popMsg(state); }
- : colName=identifier COLON colType (KW_COMMENT comment=StringLiteral)?
- -> {$comment == null}? ^(TOK_TABCOL $colName colType)
- -> ^(TOK_TABCOL $colName colType $comment)
- ;
-
-colType
-@init { pushMsg("column type", state); }
-@after { popMsg(state); }
- : type
- ;
-
-colTypeList
-@init { pushMsg("column type list", state); }
-@after { popMsg(state); }
- : colType (COMMA colType)* -> ^(TOK_COLTYPELIST colType+)
- ;
-
-type
- : primitiveType
- | listType
- | structType
- | mapType
- | unionType;
-
-primitiveType
-@init { pushMsg("primitive type specification", state); }
-@after { popMsg(state); }
- : KW_TINYINT -> TOK_TINYINT
- | KW_SMALLINT -> TOK_SMALLINT
- | KW_INT -> TOK_INT
- | KW_BIGINT -> TOK_BIGINT
- | KW_LONG -> TOK_BIGINT
- | KW_BOOLEAN -> TOK_BOOLEAN
- | KW_FLOAT -> TOK_FLOAT
- | KW_DOUBLE -> TOK_DOUBLE
- | KW_DATE -> TOK_DATE
- | KW_DATETIME -> TOK_DATETIME
- | KW_TIMESTAMP -> TOK_TIMESTAMP
- // Uncomment to allow intervals as table column types
- //| KW_INTERVAL KW_YEAR KW_TO KW_MONTH -> TOK_INTERVAL_YEAR_MONTH
- //| KW_INTERVAL KW_DAY KW_TO KW_SECOND -> TOK_INTERVAL_DAY_TIME
- | KW_STRING -> TOK_STRING
- | KW_BINARY -> TOK_BINARY
- | KW_DECIMAL (LPAREN prec=Number (COMMA scale=Number)? RPAREN)? -> ^(TOK_DECIMAL $prec? $scale?)
- | KW_VARCHAR LPAREN length=Number RPAREN -> ^(TOK_VARCHAR $length)
- | KW_CHAR LPAREN length=Number RPAREN -> ^(TOK_CHAR $length)
- ;
-
-listType
-@init { pushMsg("list type", state); }
-@after { popMsg(state); }
- : KW_ARRAY LESSTHAN type GREATERTHAN -> ^(TOK_LIST type)
- ;
-
-structType
-@init { pushMsg("struct type", state); }
-@after { popMsg(state); }
- : KW_STRUCT LESSTHAN columnNameColonTypeList GREATERTHAN -> ^(TOK_STRUCT columnNameColonTypeList)
- ;
-
-mapType
-@init { pushMsg("map type", state); }
-@after { popMsg(state); }
- : KW_MAP LESSTHAN left=type COMMA right=type GREATERTHAN
- -> ^(TOK_MAP $left $right)
- ;
-
-unionType
-@init { pushMsg("uniontype type", state); }
-@after { popMsg(state); }
- : KW_UNIONTYPE LESSTHAN colTypeList GREATERTHAN -> ^(TOK_UNIONTYPE colTypeList)
- ;
-
-setOperator
-@init { pushMsg("set operator", state); }
-@after { popMsg(state); }
- : KW_UNION KW_ALL -> ^(TOK_UNIONALL)
- | KW_UNION KW_DISTINCT? -> ^(TOK_UNIONDISTINCT)
- | KW_EXCEPT -> ^(TOK_EXCEPT)
- | KW_INTERSECT -> ^(TOK_INTERSECT)
- ;
-
-queryStatementExpression[boolean topLevel]
- :
-    /* Would be nice to do this as a gated semantic predicate
-       But the predicate gets pushed as a lookahead decision.
-       Calling rule does not know about topLevel
- */
- (w=withClause {topLevel}?)?
- queryStatementExpressionBody[topLevel] {
- if ($w.tree != null) {
- $queryStatementExpressionBody.tree.insertChild(0, $w.tree);
- }
- }
- -> queryStatementExpressionBody
- ;
-
-queryStatementExpressionBody[boolean topLevel]
- :
- fromStatement[topLevel]
- | regularBody[topLevel]
- ;
-
-withClause
- :
- KW_WITH cteStatement (COMMA cteStatement)* -> ^(TOK_CTE cteStatement+)
-;
-
-cteStatement
- :
- identifier KW_AS LPAREN queryStatementExpression[false] RPAREN
- -> ^(TOK_SUBQUERY queryStatementExpression identifier)
-;
-
-fromStatement[boolean topLevel]
-: (singleFromStatement -> singleFromStatement)
- (u=setOperator r=singleFromStatement
- -> ^($u {$fromStatement.tree} $r)
- )*
- -> {u != null && topLevel}? ^(TOK_QUERY
- ^(TOK_FROM
- ^(TOK_SUBQUERY
- {$fromStatement.tree}
- {adaptor.create(Identifier, generateUnionAlias())}
- )
- )
- ^(TOK_INSERT
- ^(TOK_DESTINATION ^(TOK_DIR TOK_TMP_FILE))
- ^(TOK_SELECT ^(TOK_SELEXPR TOK_ALLCOLREF))
- )
- )
- -> {$fromStatement.tree}
- ;
-
-
-singleFromStatement
- :
- fromClause
- ( b+=body )+ -> ^(TOK_QUERY fromClause body+)
- ;
-
-/*
-The valuesClause rule below ensures that the parse tree for
-"insert into table FOO values (1,2),(3,4)" looks the same as
-"insert into table FOO select a,b from (values(1,2),(3,4)) as BAR(a,b)" which itself is made to look
-very similar to the tree for "insert into table FOO select a,b from BAR". Since the virtual table name
-is implicit, it's represented as TOK_ANONYMOUS.
-*/
-regularBody[boolean topLevel]
- :
- i=insertClause
- (
- s=selectStatement[topLevel]
- {$s.tree.getFirstChildWithType(TOK_INSERT).replaceChildren(0, 0, $i.tree);} -> {$s.tree}
- |
- valuesClause
- -> ^(TOK_QUERY
- ^(TOK_FROM
- ^(TOK_VIRTUAL_TABLE ^(TOK_VIRTUAL_TABREF ^(TOK_ANONYMOUS)) valuesClause)
- )
- ^(TOK_INSERT {$i.tree} ^(TOK_SELECT ^(TOK_SELEXPR TOK_ALLCOLREF)))
- )
- )
- |
- selectStatement[topLevel]
- ;
-
-selectStatement[boolean topLevel]
- :
- (
- (
- LPAREN
- s=selectClause
- f=fromClause?
- w=whereClause?
- g=groupByClause?
- h=havingClause?
- o=orderByClause?
- c=clusterByClause?
- d=distributeByClause?
- sort=sortByClause?
- win=window_clause?
- l=limitClause?
- RPAREN
- |
- s=selectClause
- f=fromClause?
- w=whereClause?
- g=groupByClause?
- h=havingClause?
- o=orderByClause?
- c=clusterByClause?
- d=distributeByClause?
- sort=sortByClause?
- win=window_clause?
- l=limitClause?
- )
- -> ^(TOK_QUERY $f? ^(TOK_INSERT ^(TOK_DESTINATION ^(TOK_DIR TOK_TMP_FILE))
- $s $w? $g? $h? $o? $c?
- $d? $sort? $win? $l?))
- )
- (set=setOpSelectStatement[$selectStatement.tree, topLevel])?
- -> {set == null}?
- {$selectStatement.tree}
- -> {o==null && c==null && d==null && sort==null && l==null}?
- {$set.tree}
- -> {throwSetOpException()}
- ;
-
-setOpSelectStatement[CommonTree t, boolean topLevel]
- :
- ((
- u=setOperator LPAREN b=simpleSelectStatement RPAREN
- |
- u=setOperator b=simpleSelectStatement)
- -> {$setOpSelectStatement.tree != null}?
- ^($u {$setOpSelectStatement.tree} $b)
- -> ^($u {$t} $b)
- )+
- o=orderByClause?
- c=clusterByClause?
- d=distributeByClause?
- sort=sortByClause?
- win=window_clause?
- l=limitClause?
- -> {o==null && c==null && d==null && sort==null && win==null && l==null && !topLevel}?
- {$setOpSelectStatement.tree}
- -> ^(TOK_QUERY
- ^(TOK_FROM
- ^(TOK_SUBQUERY
- {$setOpSelectStatement.tree}
- {adaptor.create(Identifier, generateUnionAlias())}
- )
- )
- ^(TOK_INSERT
- ^(TOK_DESTINATION ^(TOK_DIR TOK_TMP_FILE))
- ^(TOK_SELECT ^(TOK_SELEXPR TOK_ALLCOLREF))
- $o? $c? $d? $sort? $win? $l?
- )
- )
- ;
-
-simpleSelectStatement
- :
- selectClause
- fromClause?
- whereClause?
- groupByClause?
- havingClause?
- ((window_clause) => window_clause)?
- -> ^(TOK_QUERY fromClause? ^(TOK_INSERT ^(TOK_DESTINATION ^(TOK_DIR TOK_TMP_FILE))
- selectClause whereClause? groupByClause? havingClause? window_clause?))
- ;
-
-selectStatementWithCTE
- :
- (w=withClause)?
- selectStatement[true] {
- if ($w.tree != null) {
- $selectStatement.tree.insertChild(0, $w.tree);
- }
- }
- -> selectStatement
- ;
-
-body
- :
- insertClause
- selectClause
- lateralView?
- whereClause?
- groupByClause?
- havingClause?
- orderByClause?
- clusterByClause?
- distributeByClause?
- sortByClause?
- window_clause?
- limitClause? -> ^(TOK_INSERT insertClause
- selectClause lateralView? whereClause? groupByClause? havingClause? orderByClause? clusterByClause?
- distributeByClause? sortByClause? window_clause? limitClause?)
- |
- selectClause
- lateralView?
- whereClause?
- groupByClause?
- havingClause?
- orderByClause?
- clusterByClause?
- distributeByClause?
- sortByClause?
- window_clause?
- limitClause? -> ^(TOK_INSERT ^(TOK_DESTINATION ^(TOK_DIR TOK_TMP_FILE))
- selectClause lateralView? whereClause? groupByClause? havingClause? orderByClause? clusterByClause?
- distributeByClause? sortByClause? window_clause? limitClause?)
- ;
-
-insertClause
-@init { pushMsg("insert clause", state); }
-@after { popMsg(state); }
- :
- KW_INSERT KW_OVERWRITE destination ifNotExists? -> ^(TOK_DESTINATION destination ifNotExists?)
- | KW_INSERT KW_INTO KW_TABLE? tableOrPartition (LPAREN targetCols=columnNameList RPAREN)?
- -> ^(TOK_INSERT_INTO tableOrPartition $targetCols?)
- ;
-
-destination
-@init { pushMsg("destination specification", state); }
-@after { popMsg(state); }
- :
- (local = KW_LOCAL)? KW_DIRECTORY StringLiteral tableRowFormat? tableFileFormat?
- -> ^(TOK_DIR StringLiteral $local? tableRowFormat? tableFileFormat?)
- | KW_TABLE tableOrPartition -> tableOrPartition
- ;
-
-limitClause
-@init { pushMsg("limit clause", state); }
-@after { popMsg(state); }
- :
- KW_LIMIT num=Number -> ^(TOK_LIMIT $num)
- ;
-
-//DELETE FROM <tableName> WHERE ...;
-deleteStatement
-@init { pushMsg("delete statement", state); }
-@after { popMsg(state); }
- :
- KW_DELETE KW_FROM tableName (whereClause)? -> ^(TOK_DELETE_FROM tableName whereClause?)
- ;
-
-/*SET <columName> = (3 + col2)*/
-columnAssignmentClause
- :
- tableOrColumn EQUAL^ precedencePlusExpression
- ;
-
-/*SET col1 = 5, col2 = (4 + col4), ...*/
-setColumnsClause
- :
- KW_SET columnAssignmentClause (COMMA columnAssignmentClause)* -> ^(TOK_SET_COLUMNS_CLAUSE columnAssignmentClause* )
- ;
-
-/*
- UPDATE <table>
- SET col1 = val1, col2 = val2... WHERE ...
-*/
-updateStatement
-@init { pushMsg("update statement", state); }
-@after { popMsg(state); }
- :
- KW_UPDATE tableName setColumnsClause whereClause? -> ^(TOK_UPDATE_TABLE tableName setColumnsClause whereClause?)
- ;
-
-/*
-BEGIN user defined transaction boundaries; follows SQL 2003 standard exactly except for addition of
-"setAutoCommitStatement" which is not in the standard doc but is supported by most SQL engines.
-*/
-sqlTransactionStatement
-@init { pushMsg("transaction statement", state); }
-@after { popMsg(state); }
- : startTransactionStatement
- | commitStatement
- | rollbackStatement
- | setAutoCommitStatement
- ;
-
-startTransactionStatement
- :
- KW_START KW_TRANSACTION ( transactionMode ( COMMA transactionMode )* )? -> ^(TOK_START_TRANSACTION transactionMode*)
- ;
-
-transactionMode
- :
- isolationLevel
- | transactionAccessMode -> ^(TOK_TXN_ACCESS_MODE transactionAccessMode)
- ;
-
-transactionAccessMode
- :
- KW_READ KW_ONLY -> TOK_TXN_READ_ONLY
- | KW_READ KW_WRITE -> TOK_TXN_READ_WRITE
- ;
-
-isolationLevel
- :
- KW_ISOLATION KW_LEVEL levelOfIsolation -> ^(TOK_ISOLATION_LEVEL levelOfIsolation)
- ;
-
-/*READ UNCOMMITTED | READ COMMITTED | REPEATABLE READ | SERIALIZABLE may be supported later*/
-levelOfIsolation
- :
- KW_SNAPSHOT -> TOK_ISOLATION_SNAPSHOT
- ;
-
-commitStatement
- :
- KW_COMMIT ( KW_WORK )? -> TOK_COMMIT
- ;
-
-rollbackStatement
- :
- KW_ROLLBACK ( KW_WORK )? -> TOK_ROLLBACK
- ;
-setAutoCommitStatement
- :
- KW_SET KW_AUTOCOMMIT booleanValueTok -> ^(TOK_SET_AUTOCOMMIT booleanValueTok)
- ;
-/*
-END user defined transaction boundaries
-*/
-
-/*
-Table Caching statements.
- */
-cacheStatement
-@init { pushMsg("cache statement", state); }
-@after { popMsg(state); }
- :
- cacheTableStatement
- | uncacheTableStatement
- | clearCacheStatement
- ;
-
-cacheTableStatement
- :
- KW_CACHE (lazy=KW_LAZY)? KW_TABLE identifier (KW_AS selectStatementWithCTE)? -> ^(TOK_CACHETABLE identifier $lazy? selectStatementWithCTE?)
- ;
-
-uncacheTableStatement
- :
- KW_UNCACHE KW_TABLE identifier -> ^(TOK_UNCACHETABLE identifier)
- ;
-
-clearCacheStatement
- :
- KW_CLEAR KW_CACHE -> ^(TOK_CLEARCACHE)
- ;
-
diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4
new file mode 100644
index 0000000000..9cf2dd257e
--- /dev/null
+++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4
@@ -0,0 +1,957 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ * This file is an adaptation of Presto's presto-parser/src/main/antlr4/com/facebook/presto/sql/parser/SqlBase.g4 grammar.
+ */
+
+grammar SqlBase;
+
+tokens {
+ DELIMITER
+}
+
+singleStatement
+ : statement EOF
+ ;
+
+singleExpression
+ : namedExpression EOF
+ ;
+
+singleTableIdentifier
+ : tableIdentifier EOF
+ ;
+
+singleDataType
+ : dataType EOF
+ ;
+
+statement
+ : query #statementDefault
+ | USE db=identifier #use
+ | CREATE DATABASE (IF NOT EXISTS)? identifier
+ (COMMENT comment=STRING)? locationSpec?
+ (WITH DBPROPERTIES tablePropertyList)? #createDatabase
+ | ALTER DATABASE identifier SET DBPROPERTIES tablePropertyList #setDatabaseProperties
+ | DROP DATABASE (IF EXISTS)? identifier (RESTRICT | CASCADE)? #dropDatabase
+ | createTableHeader ('(' colTypeList ')')? tableProvider
+ (OPTIONS tablePropertyList)? #createTableUsing
+ | createTableHeader tableProvider
+ (OPTIONS tablePropertyList)? AS? query #createTableUsing
+ | createTableHeader ('(' columns=colTypeList ')')?
+ (COMMENT STRING)?
+ (PARTITIONED BY '(' partitionColumns=colTypeList ')')?
+ bucketSpec? skewSpec?
+ rowFormat? createFileFormat? locationSpec?
+ (TBLPROPERTIES tablePropertyList)?
+ (AS? query)? #createTable
+ | CREATE TABLE (IF NOT EXISTS)? target=tableIdentifier
+ LIKE source=tableIdentifier #createTableLike
+ | ANALYZE TABLE tableIdentifier partitionSpec? COMPUTE STATISTICS
+ (identifier | FOR COLUMNS identifierSeq?)? #analyze
+ | ALTER (TABLE | VIEW) from=tableIdentifier
+ RENAME TO to=tableIdentifier #renameTable
+ | ALTER (TABLE | VIEW) tableIdentifier
+ SET TBLPROPERTIES tablePropertyList #setTableProperties
+ | ALTER (TABLE | VIEW) tableIdentifier
+ UNSET TBLPROPERTIES (IF EXISTS)? tablePropertyList #unsetTableProperties
+ | ALTER TABLE tableIdentifier (partitionSpec)?
+ SET SERDE STRING (WITH SERDEPROPERTIES tablePropertyList)? #setTableSerDe
+ | ALTER TABLE tableIdentifier (partitionSpec)?
+ SET SERDEPROPERTIES tablePropertyList #setTableSerDe
+ | ALTER TABLE tableIdentifier bucketSpec #bucketTable
+ | ALTER TABLE tableIdentifier NOT CLUSTERED #unclusterTable
+ | ALTER TABLE tableIdentifier NOT SORTED #unsortTable
+ | ALTER TABLE tableIdentifier skewSpec #skewTable
+ | ALTER TABLE tableIdentifier NOT SKEWED #unskewTable
+ | ALTER TABLE tableIdentifier NOT STORED AS DIRECTORIES #unstoreTable
+ | ALTER TABLE tableIdentifier
+ SET SKEWED LOCATION skewedLocationList #setTableSkewLocations
+ | ALTER TABLE tableIdentifier ADD (IF NOT EXISTS)?
+ partitionSpecLocation+ #addTablePartition
+ | ALTER VIEW tableIdentifier ADD (IF NOT EXISTS)?
+ partitionSpec+ #addTablePartition
+ | ALTER TABLE tableIdentifier
+ from=partitionSpec RENAME TO to=partitionSpec #renameTablePartition
+ | ALTER TABLE from=tableIdentifier
+ EXCHANGE partitionSpec WITH TABLE to=tableIdentifier #exchangeTablePartition
+ | ALTER TABLE tableIdentifier
+ DROP (IF EXISTS)? partitionSpec (',' partitionSpec)* PURGE? #dropTablePartitions
+ | ALTER VIEW tableIdentifier
+ DROP (IF EXISTS)? partitionSpec (',' partitionSpec)* #dropTablePartitions
+ | ALTER TABLE tableIdentifier ARCHIVE partitionSpec #archiveTablePartition
+ | ALTER TABLE tableIdentifier UNARCHIVE partitionSpec #unarchiveTablePartition
+ | ALTER TABLE tableIdentifier partitionSpec?
+ SET FILEFORMAT fileFormat #setTableFileFormat
+ | ALTER TABLE tableIdentifier partitionSpec? SET locationSpec #setTableLocation
+ | ALTER TABLE tableIdentifier TOUCH partitionSpec? #touchTable
+ | ALTER TABLE tableIdentifier partitionSpec? COMPACT STRING #compactTable
+ | ALTER TABLE tableIdentifier partitionSpec? CONCATENATE #concatenateTable
+ | ALTER TABLE tableIdentifier partitionSpec?
+ CHANGE COLUMN? oldName=identifier colType
+ (FIRST | AFTER after=identifier)? (CASCADE | RESTRICT)? #changeColumn
+ | ALTER TABLE tableIdentifier partitionSpec?
+ ADD COLUMNS '(' colTypeList ')' (CASCADE | RESTRICT)? #addColumns
+ | ALTER TABLE tableIdentifier partitionSpec?
+ REPLACE COLUMNS '(' colTypeList ')' (CASCADE | RESTRICT)? #replaceColumns
+ | DROP TABLE (IF EXISTS)? tableIdentifier PURGE?
+ (FOR METADATA? REPLICATION '(' STRING ')')? #dropTable
+ | DROP VIEW (IF EXISTS)? tableIdentifier #dropTable
+ | CREATE (OR REPLACE)? VIEW (IF NOT EXISTS)? tableIdentifier
+ identifierCommentList? (COMMENT STRING)?
+ (PARTITIONED ON identifierList)?
+ (TBLPROPERTIES tablePropertyList)? AS query #createView
+ | ALTER VIEW tableIdentifier AS? query #alterViewQuery
+ | CREATE TEMPORARY? FUNCTION qualifiedName AS className=STRING
+ (USING resource (',' resource)*)? #createFunction
+ | DROP TEMPORARY? FUNCTION (IF EXISTS)? qualifiedName #dropFunction
+ | EXPLAIN explainOption* statement #explain
+ | SHOW TABLES ((FROM | IN) db=identifier)?
+ (LIKE? pattern=STRING)? #showTables
+ | SHOW DATABASES (LIKE pattern=STRING)? #showDatabases
+ | SHOW TBLPROPERTIES table=tableIdentifier
+ ('(' key=tablePropertyKey ')')? #showTblProperties
+ | SHOW FUNCTIONS (LIKE? (qualifiedName | pattern=STRING))? #showFunctions
+ | (DESC | DESCRIBE) FUNCTION EXTENDED? qualifiedName #describeFunction
+ | (DESC | DESCRIBE) option=(EXTENDED | FORMATTED)?
+ tableIdentifier partitionSpec? describeColName? #describeTable
+ | (DESC | DESCRIBE) DATABASE EXTENDED? identifier #describeDatabase
+ | REFRESH TABLE tableIdentifier #refreshTable
+ | CACHE LAZY? TABLE identifier (AS? query)? #cacheTable
+ | UNCACHE TABLE identifier #uncacheTable
+ | CLEAR CACHE #clearCache
+ | ADD identifier .*? #addResource
+ | SET ROLE .*? #failNativeCommand
+ | SET .*? #setConfiguration
+ | kws=unsupportedHiveNativeCommands .*? #failNativeCommand
+ | hiveNativeCommands #executeNativeCommand
+ ;
+
+hiveNativeCommands
+ : DELETE FROM tableIdentifier (WHERE booleanExpression)?
+ | TRUNCATE TABLE tableIdentifier partitionSpec?
+ (COLUMNS identifierList)?
+ | SHOW COLUMNS (FROM | IN) tableIdentifier ((FROM|IN) identifier)?
+ | START TRANSACTION (transactionMode (',' transactionMode)*)?
+ | COMMIT WORK?
+ | ROLLBACK WORK?
+ | SHOW PARTITIONS tableIdentifier partitionSpec?
+ | DFS .*?
+ | (CREATE | ALTER | DROP | SHOW | DESC | DESCRIBE | LOAD) .*?
+ ;
+
+unsupportedHiveNativeCommands
+ : kw1=CREATE kw2=ROLE
+ | kw1=DROP kw2=ROLE
+ | kw1=GRANT kw2=ROLE?
+ | kw1=REVOKE kw2=ROLE?
+ | kw1=SHOW kw2=GRANT
+ | kw1=SHOW kw2=ROLE kw3=GRANT?
+ | kw1=SHOW kw2=PRINCIPALS
+ | kw1=SHOW kw2=ROLES
+ | kw1=SHOW kw2=CURRENT kw3=ROLES
+ | kw1=EXPORT kw2=TABLE
+ | kw1=IMPORT kw2=TABLE
+ | kw1=SHOW kw2=COMPACTIONS
+ | kw1=SHOW kw2=CREATE kw3=TABLE
+ | kw1=SHOW kw2=TRANSACTIONS
+ | kw1=SHOW kw2=INDEXES
+ | kw1=SHOW kw2=LOCKS
+ | kw1=CREATE kw2=INDEX
+ | kw1=DROP kw2=INDEX
+ | kw1=ALTER kw2=INDEX
+ | kw1=LOCK kw2=TABLE
+ | kw1=LOCK kw2=DATABASE
+ | kw1=UNLOCK kw2=TABLE
+ | kw1=UNLOCK kw2=DATABASE
+ | kw1=CREATE kw2=TEMPORARY kw3=MACRO
+ | kw1=DROP kw2=TEMPORARY kw3=MACRO
+ | kw1=MSCK kw2=REPAIR kw3=TABLE
+ ;
+
+createTableHeader
+ : CREATE TEMPORARY? EXTERNAL? TABLE (IF NOT EXISTS)? tableIdentifier
+ ;
+
+bucketSpec
+ : CLUSTERED BY identifierList
+ (SORTED BY orderedIdentifierList)?
+ INTO INTEGER_VALUE BUCKETS
+ ;
+
+skewSpec
+ : SKEWED BY identifierList
+ ON (constantList | nestedConstantList)
+ (STORED AS DIRECTORIES)?
+ ;
+
+locationSpec
+ : LOCATION STRING
+ ;
+
+query
+ : ctes? queryNoWith
+ ;
+
+insertInto
+ : INSERT OVERWRITE TABLE tableIdentifier partitionSpec? (IF NOT EXISTS)?
+ | INSERT INTO TABLE? tableIdentifier partitionSpec?
+ ;
+
+partitionSpecLocation
+ : partitionSpec locationSpec?
+ ;
+
+partitionSpec
+ : PARTITION '(' partitionVal (',' partitionVal)* ')'
+ ;
+
+partitionVal
+ : identifier (EQ constant)?
+ ;
+
+describeColName
+ : identifier ('.' (identifier | STRING))*
+ ;
+
+ctes
+ : WITH namedQuery (',' namedQuery)*
+ ;
+
+namedQuery
+ : name=identifier AS? '(' queryNoWith ')'
+ ;
+
+tableProvider
+ : USING qualifiedName
+ ;
+
+tablePropertyList
+ : '(' tableProperty (',' tableProperty)* ')'
+ ;
+
+tableProperty
+ : key=tablePropertyKey (EQ? value=STRING)?
+ ;
+
+tablePropertyKey
+ : looseIdentifier ('.' looseIdentifier)*
+ | STRING
+ ;
+
+constantList
+ : '(' constant (',' constant)* ')'
+ ;
+
+nestedConstantList
+ : '(' constantList (',' constantList)* ')'
+ ;
+
+skewedLocation
+ : (constant | constantList) EQ STRING
+ ;
+
+skewedLocationList
+ : '(' skewedLocation (',' skewedLocation)* ')'
+ ;
+
+createFileFormat
+ : STORED AS fileFormat
+ | STORED BY storageHandler
+ ;
+
+fileFormat
+ : INPUTFORMAT inFmt=STRING OUTPUTFORMAT outFmt=STRING (SERDE serdeCls=STRING)? #tableFileFormat
+ | identifier #genericFileFormat
+ ;
+
+storageHandler
+ : STRING (WITH SERDEPROPERTIES tablePropertyList)?
+ ;
+
+resource
+ : identifier STRING
+ ;
+
+queryNoWith
+ : insertInto? queryTerm queryOrganization #singleInsertQuery
+ | fromClause multiInsertQueryBody+ #multiInsertQuery
+ ;
+
+queryOrganization
+ : (ORDER BY order+=sortItem (',' order+=sortItem)*)?
+ (CLUSTER BY clusterBy+=expression (',' clusterBy+=expression)*)?
+ (DISTRIBUTE BY distributeBy+=expression (',' distributeBy+=expression)*)?
+ (SORT BY sort+=sortItem (',' sort+=sortItem)*)?
+ windows?
+ (LIMIT limit=expression)?
+ ;
+
+multiInsertQueryBody
+ : insertInto?
+ querySpecification
+ queryOrganization
+ ;
+
+queryTerm
+ : queryPrimary #queryTermDefault
+ | left=queryTerm operator=(INTERSECT | UNION | EXCEPT) setQuantifier? right=queryTerm #setOperation
+ ;
+
+queryPrimary
+ : querySpecification #queryPrimaryDefault
+ | TABLE tableIdentifier #table
+ | inlineTable #inlineTableDefault1
+ | '(' queryNoWith ')' #subquery
+ ;
+
+sortItem
+ : expression ordering=(ASC | DESC)?
+ ;
+
+querySpecification
+ : (((SELECT kind=TRANSFORM '(' namedExpressionSeq ')'
+ | kind=MAP namedExpressionSeq
+ | kind=REDUCE namedExpressionSeq))
+ inRowFormat=rowFormat?
+ (RECORDWRITER recordWriter=STRING)?
+ USING script=STRING
+ (AS (identifierSeq | colTypeList | ('(' (identifierSeq | colTypeList) ')')))?
+ outRowFormat=rowFormat?
+ (RECORDREADER recordReader=STRING)?
+ fromClause?
+ (WHERE where=booleanExpression)?)
+ | ((kind=SELECT setQuantifier? namedExpressionSeq fromClause?
+ | fromClause (kind=SELECT setQuantifier? namedExpressionSeq)?)
+ lateralView*
+ (WHERE where=booleanExpression)?
+ aggregation?
+ (HAVING having=booleanExpression)?
+ windows?)
+ ;
+
+fromClause
+ : FROM relation (',' relation)* lateralView*
+ ;
+
+aggregation
+ : GROUP BY groupingExpressions+=expression (',' groupingExpressions+=expression)* (
+ WITH kind=ROLLUP
+ | WITH kind=CUBE
+ | kind=GROUPING SETS '(' groupingSet (',' groupingSet)* ')')?
+ ;
+
+groupingSet
+ : '(' (expression (',' expression)*)? ')'
+ | expression
+ ;
+
+lateralView
+ : LATERAL VIEW (OUTER)? qualifiedName '(' (expression (',' expression)*)? ')' tblName=identifier (AS? colName+=identifier (',' colName+=identifier)*)?
+ ;
+
+setQuantifier
+ : DISTINCT
+ | ALL
+ ;
+
+relation
+ : left=relation
+ ((CROSS | joinType) JOIN right=relation joinCriteria?
+ | NATURAL joinType JOIN right=relation
+ ) #joinRelation
+ | relationPrimary #relationDefault
+ ;
+
+joinType
+ : INNER?
+ | LEFT OUTER?
+ | LEFT SEMI
+ | RIGHT OUTER?
+ | FULL OUTER?
+ | LEFT? ANTI
+ ;
+
+joinCriteria
+ : ON booleanExpression
+ | USING '(' identifier (',' identifier)* ')'
+ ;
+
+sample
+ : TABLESAMPLE '('
+ ( (percentage=(INTEGER_VALUE | DECIMAL_VALUE) sampleType=PERCENTLIT)
+ | (expression sampleType=ROWS)
+ | (sampleType=BUCKET numerator=INTEGER_VALUE OUT OF denominator=INTEGER_VALUE (ON identifier)?))
+ ')'
+ ;
+
+identifierList
+ : '(' identifierSeq ')'
+ ;
+
+identifierSeq
+ : identifier (',' identifier)*
+ ;
+
+orderedIdentifierList
+ : '(' orderedIdentifier (',' orderedIdentifier)* ')'
+ ;
+
+orderedIdentifier
+ : identifier ordering=(ASC | DESC)?
+ ;
+
+identifierCommentList
+ : '(' identifierComment (',' identifierComment)* ')'
+ ;
+
+identifierComment
+ : identifier (COMMENT STRING)?
+ ;
+
+relationPrimary
+ : tableIdentifier sample? (AS? identifier)? #tableName
+ | '(' queryNoWith ')' sample? (AS? identifier)? #aliasedQuery
+ | '(' relation ')' sample? (AS? identifier)? #aliasedRelation
+ | inlineTable #inlineTableDefault2
+ ;
+
+inlineTable
+ : VALUES expression (',' expression)* (AS? identifier identifierList?)?
+ ;
+
+rowFormat
+ : ROW FORMAT SERDE name=STRING (WITH SERDEPROPERTIES props=tablePropertyList)? #rowFormatSerde
+ | ROW FORMAT DELIMITED
+ (FIELDS TERMINATED BY fieldsTerminatedBy=STRING (ESCAPED BY escapedBy=STRING)?)?
+ (COLLECTION ITEMS TERMINATED BY collectionItemsTerminatedBy=STRING)?
+ (MAP KEYS TERMINATED BY keysTerminatedBy=STRING)?
+ (LINES TERMINATED BY linesSeparatedBy=STRING)?
+ (NULL DEFINED AS nullDefinedAs=STRING)? #rowFormatDelimited
+ ;
+
+tableIdentifier
+ : (db=identifier '.')? table=identifier
+ ;
+
+namedExpression
+ : expression (AS? (identifier | identifierList))?
+ ;
+
+namedExpressionSeq
+ : namedExpression (',' namedExpression)*
+ ;
+
+expression
+ : booleanExpression
+ ;
+
+booleanExpression
+ : predicated #booleanDefault
+ | NOT booleanExpression #logicalNot
+ | left=booleanExpression operator=AND right=booleanExpression #logicalBinary
+ | left=booleanExpression operator=OR right=booleanExpression #logicalBinary
+ | EXISTS '(' query ')' #exists
+ ;
+
+// workaround for:
+// https://github.com/antlr/antlr4/issues/780
+// https://github.com/antlr/antlr4/issues/781
+predicated
+ : valueExpression predicate?
+ ;
+
+predicate
+ : NOT? kind=BETWEEN lower=valueExpression AND upper=valueExpression
+ | NOT? kind=IN '(' expression (',' expression)* ')'
+ | NOT? kind=IN '(' query ')'
+ | NOT? kind=(RLIKE | LIKE) pattern=valueExpression
+ | IS NOT? kind=NULL
+ ;
+
+valueExpression
+ : primaryExpression #valueExpressionDefault
+ | operator=(MINUS | PLUS | TILDE) valueExpression #arithmeticUnary
+ | left=valueExpression operator=(ASTERISK | SLASH | PERCENT | DIV) right=valueExpression #arithmeticBinary
+ | left=valueExpression operator=(PLUS | MINUS) right=valueExpression #arithmeticBinary
+ | left=valueExpression operator=AMPERSAND right=valueExpression #arithmeticBinary
+ | left=valueExpression operator=HAT right=valueExpression #arithmeticBinary
+ | left=valueExpression operator=PIPE right=valueExpression #arithmeticBinary
+ | left=valueExpression comparisonOperator right=valueExpression #comparison
+ ;
+
+primaryExpression
+ : constant #constantDefault
+ | ASTERISK #star
+ | qualifiedName '.' ASTERISK #star
+ | '(' expression (',' expression)+ ')' #rowConstructor
+ | qualifiedName '(' (setQuantifier? expression (',' expression)*)? ')' (OVER windowSpec)? #functionCall
+ | '(' query ')' #subqueryExpression
+ | CASE valueExpression whenClause+ (ELSE elseExpression=expression)? END #simpleCase
+ | CASE whenClause+ (ELSE elseExpression=expression)? END #searchedCase
+ | CAST '(' expression AS dataType ')' #cast
+ | value=primaryExpression '[' index=valueExpression ']' #subscript
+ | identifier #columnReference
+ | base=primaryExpression '.' fieldName=identifier #dereference
+ | '(' expression ')' #parenthesizedExpression
+ ;
+
+constant
+ : NULL #nullLiteral
+ | interval #intervalLiteral
+ | identifier STRING #typeConstructor
+ | number #numericLiteral
+ | booleanValue #booleanLiteral
+ | STRING+ #stringLiteral
+ ;
+
+comparisonOperator
+ : EQ | NEQ | NEQJ | LT | LTE | GT | GTE | NSEQ
+ ;
+
+booleanValue
+ : TRUE | FALSE
+ ;
+
+interval
+ : INTERVAL intervalField*
+ ;
+
+intervalField
+ : value=intervalValue unit=identifier (TO to=identifier)?
+ ;
+
+intervalValue
+ : (PLUS | MINUS)? (INTEGER_VALUE | DECIMAL_VALUE)
+ | STRING
+ ;
+
+dataType
+ : complex=ARRAY '<' dataType '>' #complexDataType
+ | complex=MAP '<' dataType ',' dataType '>' #complexDataType
+ | complex=STRUCT ('<' colTypeList? '>' | NEQ) #complexDataType
+ | identifier ('(' INTEGER_VALUE (',' INTEGER_VALUE)* ')')? #primitiveDataType
+ ;
+
+colTypeList
+ : colType (',' colType)*
+ ;
+
+colType
+ : identifier ':'? dataType (COMMENT STRING)?
+ ;
+
+whenClause
+ : WHEN condition=expression THEN result=expression
+ ;
+
+windows
+ : WINDOW namedWindow (',' namedWindow)*
+ ;
+
+namedWindow
+ : identifier AS windowSpec
+ ;
+
+windowSpec
+ : name=identifier #windowRef
+ | '('
+ ( CLUSTER BY partition+=expression (',' partition+=expression)*
+ | ((PARTITION | DISTRIBUTE) BY partition+=expression (',' partition+=expression)*)?
+ ((ORDER | SORT) BY sortItem (',' sortItem)*)?)
+ windowFrame?
+ ')' #windowDef
+ ;
+
+windowFrame
+ : frameType=RANGE start=frameBound
+ | frameType=ROWS start=frameBound
+ | frameType=RANGE BETWEEN start=frameBound AND end=frameBound
+ | frameType=ROWS BETWEEN start=frameBound AND end=frameBound
+ ;
+
+frameBound
+ : UNBOUNDED boundType=(PRECEDING | FOLLOWING)
+ | boundType=CURRENT ROW
+ | expression boundType=(PRECEDING | FOLLOWING)
+ ;
+
+
+explainOption
+ : LOGICAL | FORMATTED | EXTENDED | CODEGEN
+ ;
+
+transactionMode
+ : ISOLATION LEVEL SNAPSHOT #isolationLevel
+ | READ accessMode=(ONLY | WRITE) #transactionAccessMode
+ ;
+
+qualifiedName
+ : identifier ('.' identifier)*
+ ;
+
+// Identifier that also allows the use of a number of SQL keywords (mainly for backwards compatibility).
+looseIdentifier
+ : identifier
+ | FROM
+ | TO
+ | TABLE
+ | WITH
+ ;
+
+identifier
+ : IDENTIFIER #unquotedIdentifier
+ | quotedIdentifier #quotedIdentifierAlternative
+ | nonReserved #unquotedIdentifier
+ ;
+
+quotedIdentifier
+ : BACKQUOTED_IDENTIFIER
+ ;
+
+number
+ : DECIMAL_VALUE #decimalLiteral
+ | SCIENTIFIC_DECIMAL_VALUE #scientificDecimalLiteral
+ | INTEGER_VALUE #integerLiteral
+ | BIGINT_LITERAL #bigIntLiteral
+ | SMALLINT_LITERAL #smallIntLiteral
+ | TINYINT_LITERAL #tinyIntLiteral
+ | DOUBLE_LITERAL #doubleLiteral
+ ;
+
+nonReserved
+ : SHOW | TABLES | COLUMNS | COLUMN | PARTITIONS | FUNCTIONS | DATABASES
+ | ADD
+ | OVER | PARTITION | RANGE | ROWS | PRECEDING | FOLLOWING | CURRENT | ROW | MAP | ARRAY | STRUCT
+ | LATERAL | WINDOW | REDUCE | TRANSFORM | USING | SERDE | SERDEPROPERTIES | RECORDREADER
+ | DELIMITED | FIELDS | TERMINATED | COLLECTION | ITEMS | KEYS | ESCAPED | LINES | SEPARATED
+ | EXTENDED | REFRESH | CLEAR | CACHE | UNCACHE | LAZY | TEMPORARY | OPTIONS
+ | GROUPING | CUBE | ROLLUP
+ | EXPLAIN | FORMAT | LOGICAL | FORMATTED | CODEGEN
+ | TABLESAMPLE | USE | TO | BUCKET | PERCENTLIT | OUT | OF
+ | SET
+ | VIEW | REPLACE
+ | IF
+ | NO | DATA
+ | START | TRANSACTION | COMMIT | ROLLBACK | WORK | ISOLATION | LEVEL
+ | SNAPSHOT | READ | WRITE | ONLY
+ | SORT | CLUSTER | DISTRIBUTE | UNSET | TBLPROPERTIES | SKEWED | STORED | DIRECTORIES | LOCATION
+ | EXCHANGE | ARCHIVE | UNARCHIVE | FILEFORMAT | TOUCH | COMPACT | CONCATENATE | CHANGE | FIRST
+ | AFTER | CASCADE | RESTRICT | BUCKETS | CLUSTERED | SORTED | PURGE | INPUTFORMAT | OUTPUTFORMAT
+ | INPUTDRIVER | OUTPUTDRIVER | DBPROPERTIES | DFS | TRUNCATE | METADATA | REPLICATION | COMPUTE
+ | STATISTICS | ANALYZE | PARTITIONED | EXTERNAL | DEFINED | RECORDWRITER
+ | REVOKE | GRANT | LOCK | UNLOCK | MSCK | REPAIR | EXPORT | IMPORT | LOAD | VALUES | COMMENT | ROLE
+ | ROLES | COMPACTIONS | PRINCIPALS | TRANSACTIONS | INDEX | INDEXES | LOCKS | OPTION
+ ;
+
+SELECT: 'SELECT';
+FROM: 'FROM';
+ADD: 'ADD';
+AS: 'AS';
+ALL: 'ALL';
+DISTINCT: 'DISTINCT';
+WHERE: 'WHERE';
+GROUP: 'GROUP';
+BY: 'BY';
+GROUPING: 'GROUPING';
+SETS: 'SETS';
+CUBE: 'CUBE';
+ROLLUP: 'ROLLUP';
+ORDER: 'ORDER';
+HAVING: 'HAVING';
+LIMIT: 'LIMIT';
+AT: 'AT';
+OR: 'OR';
+AND: 'AND';
+IN: 'IN';
+NOT: 'NOT' | '!';
+NO: 'NO';
+EXISTS: 'EXISTS';
+BETWEEN: 'BETWEEN';
+LIKE: 'LIKE';
+RLIKE: 'RLIKE' | 'REGEXP';
+IS: 'IS';
+NULL: 'NULL';
+TRUE: 'TRUE';
+FALSE: 'FALSE';
+NULLS: 'NULLS';
+ASC: 'ASC';
+DESC: 'DESC';
+FOR: 'FOR';
+INTERVAL: 'INTERVAL';
+CASE: 'CASE';
+WHEN: 'WHEN';
+THEN: 'THEN';
+ELSE: 'ELSE';
+END: 'END';
+JOIN: 'JOIN';
+CROSS: 'CROSS';
+OUTER: 'OUTER';
+INNER: 'INNER';
+LEFT: 'LEFT';
+SEMI: 'SEMI';
+RIGHT: 'RIGHT';
+FULL: 'FULL';
+NATURAL: 'NATURAL';
+ON: 'ON';
+LATERAL: 'LATERAL';
+WINDOW: 'WINDOW';
+OVER: 'OVER';
+PARTITION: 'PARTITION';
+RANGE: 'RANGE';
+ROWS: 'ROWS';
+UNBOUNDED: 'UNBOUNDED';
+PRECEDING: 'PRECEDING';
+FOLLOWING: 'FOLLOWING';
+CURRENT: 'CURRENT';
+ROW: 'ROW';
+WITH: 'WITH';
+VALUES: 'VALUES';
+CREATE: 'CREATE';
+TABLE: 'TABLE';
+VIEW: 'VIEW';
+REPLACE: 'REPLACE';
+INSERT: 'INSERT';
+DELETE: 'DELETE';
+INTO: 'INTO';
+DESCRIBE: 'DESCRIBE';
+EXPLAIN: 'EXPLAIN';
+FORMAT: 'FORMAT';
+LOGICAL: 'LOGICAL';
+CODEGEN: 'CODEGEN';
+CAST: 'CAST';
+SHOW: 'SHOW';
+TABLES: 'TABLES';
+COLUMNS: 'COLUMNS';
+COLUMN: 'COLUMN';
+USE: 'USE';
+PARTITIONS: 'PARTITIONS';
+FUNCTIONS: 'FUNCTIONS';
+DROP: 'DROP';
+UNION: 'UNION';
+EXCEPT: 'EXCEPT';
+INTERSECT: 'INTERSECT';
+TO: 'TO';
+TABLESAMPLE: 'TABLESAMPLE';
+STRATIFY: 'STRATIFY';
+ALTER: 'ALTER';
+RENAME: 'RENAME';
+ARRAY: 'ARRAY';
+MAP: 'MAP';
+STRUCT: 'STRUCT';
+COMMENT: 'COMMENT';
+SET: 'SET';
+DATA: 'DATA';
+START: 'START';
+TRANSACTION: 'TRANSACTION';
+COMMIT: 'COMMIT';
+ROLLBACK: 'ROLLBACK';
+WORK: 'WORK';
+ISOLATION: 'ISOLATION';
+LEVEL: 'LEVEL';
+SNAPSHOT: 'SNAPSHOT';
+READ: 'READ';
+WRITE: 'WRITE';
+ONLY: 'ONLY';
+MACRO: 'MACRO';
+
+IF: 'IF';
+
+EQ : '=' | '==';
+NSEQ: '<=>';
+NEQ : '<>';
+NEQJ: '!=';
+LT : '<';
+LTE : '<=';
+GT : '>';
+GTE : '>=';
+
+PLUS: '+';
+MINUS: '-';
+ASTERISK: '*';
+SLASH: '/';
+PERCENT: '%';
+DIV: 'DIV';
+TILDE: '~';
+AMPERSAND: '&';
+PIPE: '|';
+HAT: '^';
+
+PERCENTLIT: 'PERCENT';
+BUCKET: 'BUCKET';
+OUT: 'OUT';
+OF: 'OF';
+
+SORT: 'SORT';
+CLUSTER: 'CLUSTER';
+DISTRIBUTE: 'DISTRIBUTE';
+OVERWRITE: 'OVERWRITE';
+TRANSFORM: 'TRANSFORM';
+REDUCE: 'REDUCE';
+USING: 'USING';
+SERDE: 'SERDE';
+SERDEPROPERTIES: 'SERDEPROPERTIES';
+RECORDREADER: 'RECORDREADER';
+RECORDWRITER: 'RECORDWRITER';
+DELIMITED: 'DELIMITED';
+FIELDS: 'FIELDS';
+TERMINATED: 'TERMINATED';
+COLLECTION: 'COLLECTION';
+ITEMS: 'ITEMS';
+KEYS: 'KEYS';
+ESCAPED: 'ESCAPED';
+LINES: 'LINES';
+SEPARATED: 'SEPARATED';
+FUNCTION: 'FUNCTION';
+EXTENDED: 'EXTENDED';
+REFRESH: 'REFRESH';
+CLEAR: 'CLEAR';
+CACHE: 'CACHE';
+UNCACHE: 'UNCACHE';
+LAZY: 'LAZY';
+FORMATTED: 'FORMATTED';
+TEMPORARY: 'TEMPORARY' | 'TEMP';
+OPTIONS: 'OPTIONS';
+UNSET: 'UNSET';
+TBLPROPERTIES: 'TBLPROPERTIES';
+DBPROPERTIES: 'DBPROPERTIES';
+BUCKETS: 'BUCKETS';
+SKEWED: 'SKEWED';
+STORED: 'STORED';
+DIRECTORIES: 'DIRECTORIES';
+LOCATION: 'LOCATION';
+EXCHANGE: 'EXCHANGE';
+ARCHIVE: 'ARCHIVE';
+UNARCHIVE: 'UNARCHIVE';
+FILEFORMAT: 'FILEFORMAT';
+TOUCH: 'TOUCH';
+COMPACT: 'COMPACT';
+CONCATENATE: 'CONCATENATE';
+CHANGE: 'CHANGE';
+FIRST: 'FIRST';
+AFTER: 'AFTER';
+CASCADE: 'CASCADE';
+RESTRICT: 'RESTRICT';
+CLUSTERED: 'CLUSTERED';
+SORTED: 'SORTED';
+PURGE: 'PURGE';
+INPUTFORMAT: 'INPUTFORMAT';
+OUTPUTFORMAT: 'OUTPUTFORMAT';
+INPUTDRIVER: 'INPUTDRIVER';
+OUTPUTDRIVER: 'OUTPUTDRIVER';
+DATABASE: 'DATABASE' | 'SCHEMA';
+DATABASES: 'DATABASES' | 'SCHEMAS';
+DFS: 'DFS';
+TRUNCATE: 'TRUNCATE';
+METADATA: 'METADATA';
+REPLICATION: 'REPLICATION';
+ANALYZE: 'ANALYZE';
+COMPUTE: 'COMPUTE';
+STATISTICS: 'STATISTICS';
+PARTITIONED: 'PARTITIONED';
+EXTERNAL: 'EXTERNAL';
+DEFINED: 'DEFINED';
+REVOKE: 'REVOKE';
+GRANT: 'GRANT';
+LOCK: 'LOCK';
+UNLOCK: 'UNLOCK';
+MSCK: 'MSCK';
+REPAIR: 'REPAIR';
+EXPORT: 'EXPORT';
+IMPORT: 'IMPORT';
+LOAD: 'LOAD';
+ROLE: 'ROLE';
+ROLES: 'ROLES';
+COMPACTIONS: 'COMPACTIONS';
+PRINCIPALS: 'PRINCIPALS';
+TRANSACTIONS: 'TRANSACTIONS';
+INDEX: 'INDEX';
+INDEXES: 'INDEXES';
+LOCKS: 'LOCKS';
+OPTION: 'OPTION';
+ANTI: 'ANTI';
+
+STRING
+ : '\'' ( ~('\''|'\\') | ('\\' .) )* '\''
+ | '\"' ( ~('\"'|'\\') | ('\\' .) )* '\"'
+ ;
+
+BIGINT_LITERAL
+ : DIGIT+ 'L'
+ ;
+
+SMALLINT_LITERAL
+ : DIGIT+ 'S'
+ ;
+
+TINYINT_LITERAL
+ : DIGIT+ 'Y'
+ ;
+
+INTEGER_VALUE
+ : DIGIT+
+ ;
+
+DECIMAL_VALUE
+ : DIGIT+ '.' DIGIT*
+ | '.' DIGIT+
+ ;
+
+SCIENTIFIC_DECIMAL_VALUE
+ : DIGIT+ ('.' DIGIT*)? EXPONENT
+ | '.' DIGIT+ EXPONENT
+ ;
+
+DOUBLE_LITERAL
+ :
+ (INTEGER_VALUE | DECIMAL_VALUE | SCIENTIFIC_DECIMAL_VALUE) 'D'
+ ;
+
+IDENTIFIER
+ : (LETTER | DIGIT | '_')+
+ ;
+
+BACKQUOTED_IDENTIFIER
+ : '`' ( ~'`' | '``' )* '`'
+ ;
+
+fragment EXPONENT
+ : 'E' [+-]? DIGIT+
+ ;
+
+fragment DIGIT
+ : [0-9]
+ ;
+
+fragment LETTER
+ : [A-Z]
+ ;
+
+SIMPLE_COMMENT
+ : '--' ~[\r\n]* '\r'? '\n'? -> channel(HIDDEN)
+ ;
+
+BRACKETED_COMMENT
+ : '/*' .*? '*/' -> channel(HIDDEN)
+ ;
+
+WS
+ : [ \r\n\t]+ -> channel(HIDDEN)
+ ;
+
+// Catch-all for anything we can't recognize.
+// We use this to be able to ignore and recover all the text
+// when splitting statements with DelimiterLexer
+UNRECOGNIZED
+ : .
+ ;
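
A minimal sketch of driving the grammar above from Scala, assuming only the stock ANTLR 4 runtime plus the classes ANTLR generates from this file (SqlBaseLexer, SqlBaseParser and a singleStatement() rule method in org.apache.spark.sql.catalyst.parser); the driver Spark builds around these classes (case handling, error listeners) is not shown here:

    import org.antlr.v4.runtime.{ANTLRInputStream, CommonTokenStream}
    import org.apache.spark.sql.catalyst.parser.{SqlBaseLexer, SqlBaseParser}

    object SqlBaseParseSketch {
      def main(args: Array[String]): Unit = {
        // Keywords are declared in upper case in the grammar, so feed upper-case SQL here;
        // the real parser wraps the character stream to make keywords case-insensitive.
        val sql = "SELECT A, COUNT(1) FROM T GROUP BY A"
        val lexer = new SqlBaseLexer(new ANTLRInputStream(sql))
        val parser = new SqlBaseParser(new CommonTokenStream(lexer))
        val tree = parser.singleStatement()   // entry rule: statement EOF
        println(tree.toStringTree(parser))    // LISP-style view of the parse tree
      }
    }
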
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/parser/ParseUtils.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/parser/ParseUtils.java
deleted file mode 100644
index 01f89112a7..0000000000
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/parser/ParseUtils.java
+++ /dev/null
@@ -1,135 +0,0 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.catalyst.parser;
-
-import java.nio.charset.StandardCharsets;
-
-/**
- * A couple of utility methods that help with parsing ASTs.
- *
- * The 'unescapeSQLString' method in this class was taken from the SemanticAnalyzer in Hive:
- * ql/src/java/org/apache/hadoop/hive/ql/parse/BaseSemanticAnalyzer.java
- */
-public final class ParseUtils {
- private ParseUtils() {
- super();
- }
-
- private static final int[] multiplier = new int[] {1000, 100, 10, 1};
-
- @SuppressWarnings("nls")
- public static String unescapeSQLString(String b) {
- Character enclosure = null;
-
- // Some of the strings can be passed in as unicode. For example, the
- // delimiter can be passed in as \002 - So, we first check if the
- // string is a unicode number, else go back to the old behavior
- StringBuilder sb = new StringBuilder(b.length());
- for (int i = 0; i < b.length(); i++) {
-
- char currentChar = b.charAt(i);
- if (enclosure == null) {
- if (currentChar == '\'' || b.charAt(i) == '\"') {
- enclosure = currentChar;
- }
- // ignore all other chars outside the enclosure
- continue;
- }
-
- if (enclosure.equals(currentChar)) {
- enclosure = null;
- continue;
- }
-
- if (currentChar == '\\' && (i + 6 < b.length()) && b.charAt(i + 1) == 'u') {
- int code = 0;
- int base = i + 2;
- for (int j = 0; j < 4; j++) {
- int digit = Character.digit(b.charAt(j + base), 16);
- code += digit * multiplier[j];
- }
- sb.append((char)code);
- i += 5;
- continue;
- }
-
- if (currentChar == '\\' && (i + 4 < b.length())) {
- char i1 = b.charAt(i + 1);
- char i2 = b.charAt(i + 2);
- char i3 = b.charAt(i + 3);
- if ((i1 >= '0' && i1 <= '1') && (i2 >= '0' && i2 <= '7')
- && (i3 >= '0' && i3 <= '7')) {
- byte bVal = (byte) ((i3 - '0') + ((i2 - '0') * 8) + ((i1 - '0') * 8 * 8));
- byte[] bValArr = new byte[1];
- bValArr[0] = bVal;
- String tmp = new String(bValArr, StandardCharsets.UTF_8);
- sb.append(tmp);
- i += 3;
- continue;
- }
- }
-
- if (currentChar == '\\' && (i + 2 < b.length())) {
- char n = b.charAt(i + 1);
- switch (n) {
- case '0':
- sb.append("\0");
- break;
- case '\'':
- sb.append("'");
- break;
- case '"':
- sb.append("\"");
- break;
- case 'b':
- sb.append("\b");
- break;
- case 'n':
- sb.append("\n");
- break;
- case 'r':
- sb.append("\r");
- break;
- case 't':
- sb.append("\t");
- break;
- case 'Z':
- sb.append("\u001A");
- break;
- case '\\':
- sb.append("\\");
- break;
- // The following 2 lines are exactly what MySQL does TODO: why do we do this?
- case '%':
- sb.append("\\%");
- break;
- case '_':
- sb.append("\\_");
- break;
- default:
- sb.append(n);
- }
- i++;
- } else {
- sb.append(currentChar);
- }
- }
- return sb.toString();
- }
-}
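
The Javadoc of the file removed above describes how quoted SQL string literals were unescaped. A rough, self-contained sketch of the single-character escape handling it performed (illustrative only; the names below are not Spark APIs, and the unicode, octal and MySQL-style \% / \_ branches of the original are omitted):

    object UnescapeSketch {
      // Strip the surrounding quotes and resolve simple backslash escapes,
      // mirroring the switch statement in the removed unescapeSQLString.
      def unescapeBasic(quoted: String): String = {
        val body = quoted.substring(1, quoted.length - 1)
        val sb = new StringBuilder
        var i = 0
        while (i < body.length) {
          val c = body.charAt(i)
          if (c == '\\' && i + 1 < body.length) {
            body.charAt(i + 1) match {
              case 'n'   => sb.append('\n')
              case 't'   => sb.append('\t')
              case 'r'   => sb.append('\r')
              case '\''  => sb.append('\'')
              case '"'   => sb.append('"')
              case '\\'  => sb.append('\\')
              case other => sb.append(other) // default: drop the backslash, keep the character
            }
            i += 2
          } else {
            sb.append(c)
            i += 1
          }
        }
        sb.toString
      }

      def main(args: Array[String]): Unit = {
        println(unescapeBasic("'a\\tb'")) // prints "a", a tab, "b"
      }
    }
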
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java
index aa7fc2121e..7784345a7a 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java
@@ -151,7 +151,7 @@ public final class UnsafeExternalRowSorter {
Platform.throwException(e);
}
throw new RuntimeException("Exception should have been re-thrown in next()");
- };
+ }
};
} catch (IOException e) {
cleanupResources();
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala
index b19538a23f..ffa694fcdc 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala
@@ -17,22 +17,20 @@
package org.apache.spark.sql
-import java.lang.reflect.Modifier
-
import scala.annotation.implicitNotFound
-import scala.reflect.{classTag, ClassTag}
+import scala.reflect.ClassTag
import org.apache.spark.annotation.Experimental
-import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder}
-import org.apache.spark.sql.catalyst.expressions.{BoundReference, DecodeUsingSerializer, EncodeUsingSerializer}
import org.apache.spark.sql.types._
+
/**
* :: Experimental ::
* Used to convert a JVM object of type `T` to and from the internal Spark SQL representation.
*
* == Scala ==
- * Encoders are generally created automatically through implicits from a `SQLContext`.
+ * Encoders are generally created automatically through implicits from a `SQLContext`, or can be
+ * explicitly created by calling static methods on [[Encoders]].
*
* {{{
* import sqlContext.implicits._
@@ -81,224 +79,3 @@ trait Encoder[T] extends Serializable {
 /** A ClassTag that can be used to construct an Array to contain a collection of `T`. */
def clsTag: ClassTag[T]
}
-
-/**
- * :: Experimental ::
- * Methods for creating an [[Encoder]].
- *
- * @since 1.6.0
- */
-@Experimental
-object Encoders {
-
- /**
- * An encoder for nullable boolean type.
- * @since 1.6.0
- */
- def BOOLEAN: Encoder[java.lang.Boolean] = ExpressionEncoder()
-
- /**
- * An encoder for nullable byte type.
- * @since 1.6.0
- */
- def BYTE: Encoder[java.lang.Byte] = ExpressionEncoder()
-
- /**
- * An encoder for nullable short type.
- * @since 1.6.0
- */
- def SHORT: Encoder[java.lang.Short] = ExpressionEncoder()
-
- /**
- * An encoder for nullable int type.
- * @since 1.6.0
- */
- def INT: Encoder[java.lang.Integer] = ExpressionEncoder()
-
- /**
- * An encoder for nullable long type.
- * @since 1.6.0
- */
- def LONG: Encoder[java.lang.Long] = ExpressionEncoder()
-
- /**
- * An encoder for nullable float type.
- * @since 1.6.0
- */
- def FLOAT: Encoder[java.lang.Float] = ExpressionEncoder()
-
- /**
- * An encoder for nullable double type.
- * @since 1.6.0
- */
- def DOUBLE: Encoder[java.lang.Double] = ExpressionEncoder()
-
- /**
- * An encoder for nullable string type.
- * @since 1.6.0
- */
- def STRING: Encoder[java.lang.String] = ExpressionEncoder()
-
- /**
- * An encoder for nullable decimal type.
- * @since 1.6.0
- */
- def DECIMAL: Encoder[java.math.BigDecimal] = ExpressionEncoder()
-
- /**
- * An encoder for nullable date type.
- * @since 1.6.0
- */
- def DATE: Encoder[java.sql.Date] = ExpressionEncoder()
-
- /**
- * An encoder for nullable timestamp type.
- * @since 1.6.0
- */
- def TIMESTAMP: Encoder[java.sql.Timestamp] = ExpressionEncoder()
-
- /**
- * An encoder for arrays of bytes.
- * @since 1.6.1
- */
- def BINARY: Encoder[Array[Byte]] = ExpressionEncoder()
-
- /**
- * Creates an encoder for Java Bean of type T.
- *
- * T must be publicly accessible.
- *
- * supported types for java bean field:
- * - primitive types: boolean, int, double, etc.
- * - boxed types: Boolean, Integer, Double, etc.
- * - String
- * - java.math.BigDecimal
- * - time related: java.sql.Date, java.sql.Timestamp
- * - collection types: only array and java.util.List currently, map support is in progress
- * - nested java bean.
- *
- * @since 1.6.0
- */
- def bean[T](beanClass: Class[T]): Encoder[T] = ExpressionEncoder.javaBean(beanClass)
-
- /**
- * (Scala-specific) Creates an encoder that serializes objects of type T using Kryo.
- * This encoder maps T into a single byte array (binary) field.
- *
- * T must be publicly accessible.
- *
- * @since 1.6.0
- */
- def kryo[T: ClassTag]: Encoder[T] = genericSerializer(useKryo = true)
-
- /**
- * Creates an encoder that serializes objects of type T using Kryo.
- * This encoder maps T into a single byte array (binary) field.
- *
- * T must be publicly accessible.
- *
- * @since 1.6.0
- */
- def kryo[T](clazz: Class[T]): Encoder[T] = kryo(ClassTag[T](clazz))
-
- /**
- * (Scala-specific) Creates an encoder that serializes objects of type T using generic Java
- * serialization. This encoder maps T into a single byte array (binary) field.
- *
- * Note that this is extremely inefficient and should only be used as the last resort.
- *
- * T must be publicly accessible.
- *
- * @since 1.6.0
- */
- def javaSerialization[T: ClassTag]: Encoder[T] = genericSerializer(useKryo = false)
-
- /**
- * Creates an encoder that serializes objects of type T using generic Java serialization.
- * This encoder maps T into a single byte array (binary) field.
- *
- * Note that this is extremely inefficient and should only be used as the last resort.
- *
- * T must be publicly accessible.
- *
- * @since 1.6.0
- */
- def javaSerialization[T](clazz: Class[T]): Encoder[T] = javaSerialization(ClassTag[T](clazz))
-
- /** Throws an exception if T is not a public class. */
- private def validatePublicClass[T: ClassTag](): Unit = {
- if (!Modifier.isPublic(classTag[T].runtimeClass.getModifiers)) {
- throw new UnsupportedOperationException(
- s"${classTag[T].runtimeClass.getName} is not a public class. " +
- "Only public classes are supported.")
- }
- }
-
- /** A way to construct encoders using generic serializers. */
- private def genericSerializer[T: ClassTag](useKryo: Boolean): Encoder[T] = {
- if (classTag[T].runtimeClass.isPrimitive) {
- throw new UnsupportedOperationException("Primitive types are not supported.")
- }
-
- validatePublicClass[T]()
-
- ExpressionEncoder[T](
- schema = new StructType().add("value", BinaryType),
- flat = true,
- toRowExpressions = Seq(
- EncodeUsingSerializer(
- BoundReference(0, ObjectType(classOf[AnyRef]), nullable = true), kryo = useKryo)),
- fromRowExpression =
- DecodeUsingSerializer[T](
- BoundReference(0, BinaryType, nullable = true), classTag[T], kryo = useKryo),
- clsTag = classTag[T]
- )
- }
-
- /**
- * An encoder for 2-ary tuples.
- * @since 1.6.0
- */
- def tuple[T1, T2](
- e1: Encoder[T1],
- e2: Encoder[T2]): Encoder[(T1, T2)] = {
- ExpressionEncoder.tuple(encoderFor(e1), encoderFor(e2))
- }
-
- /**
- * An encoder for 3-ary tuples.
- * @since 1.6.0
- */
- def tuple[T1, T2, T3](
- e1: Encoder[T1],
- e2: Encoder[T2],
- e3: Encoder[T3]): Encoder[(T1, T2, T3)] = {
- ExpressionEncoder.tuple(encoderFor(e1), encoderFor(e2), encoderFor(e3))
- }
-
- /**
- * An encoder for 4-ary tuples.
- * @since 1.6.0
- */
- def tuple[T1, T2, T3, T4](
- e1: Encoder[T1],
- e2: Encoder[T2],
- e3: Encoder[T3],
- e4: Encoder[T4]): Encoder[(T1, T2, T3, T4)] = {
- ExpressionEncoder.tuple(encoderFor(e1), encoderFor(e2), encoderFor(e3), encoderFor(e4))
- }
-
- /**
- * An encoder for 5-ary tuples.
- * @since 1.6.0
- */
- def tuple[T1, T2, T3, T4, T5](
- e1: Encoder[T1],
- e2: Encoder[T2],
- e3: Encoder[T3],
- e4: Encoder[T4],
- e5: Encoder[T5]): Encoder[(T1, T2, T3, T4, T5)] = {
- ExpressionEncoder.tuple(
- encoderFor(e1), encoderFor(e2), encoderFor(e3), encoderFor(e4), encoderFor(e5))
- }
-}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala
new file mode 100644
index 0000000000..3f4df704db
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala
@@ -0,0 +1,314 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import java.lang.reflect.Modifier
+
+import scala.reflect.{classTag, ClassTag}
+import scala.reflect.runtime.universe.TypeTag
+
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder}
+import org.apache.spark.sql.catalyst.expressions.{BoundReference, DecodeUsingSerializer, EncodeUsingSerializer}
+import org.apache.spark.sql.types._
+
+/**
+ * :: Experimental ::
+ * Methods for creating an [[Encoder]].
+ *
+ * @since 1.6.0
+ */
+@Experimental
+object Encoders {
+
+ /**
+ * An encoder for nullable boolean type.
+ * The Scala primitive encoder is available as [[scalaBoolean]].
+ * @since 1.6.0
+ */
+ def BOOLEAN: Encoder[java.lang.Boolean] = ExpressionEncoder()
+
+ /**
+ * An encoder for nullable byte type.
+ * The Scala primitive encoder is available as [[scalaByte]].
+ * @since 1.6.0
+ */
+ def BYTE: Encoder[java.lang.Byte] = ExpressionEncoder()
+
+ /**
+ * An encoder for nullable short type.
+ * The Scala primitive encoder is available as [[scalaShort]].
+ * @since 1.6.0
+ */
+ def SHORT: Encoder[java.lang.Short] = ExpressionEncoder()
+
+ /**
+ * An encoder for nullable int type.
+ * The Scala primitive encoder is available as [[scalaInt]].
+ * @since 1.6.0
+ */
+ def INT: Encoder[java.lang.Integer] = ExpressionEncoder()
+
+ /**
+ * An encoder for nullable long type.
+ * The Scala primitive encoder is available as [[scalaLong]].
+ * @since 1.6.0
+ */
+ def LONG: Encoder[java.lang.Long] = ExpressionEncoder()
+
+ /**
+ * An encoder for nullable float type.
+ * The Scala primitive encoder is available as [[scalaFloat]].
+ * @since 1.6.0
+ */
+ def FLOAT: Encoder[java.lang.Float] = ExpressionEncoder()
+
+ /**
+ * An encoder for nullable double type.
+ * The Scala primitive encoder is available as [[scalaDouble]].
+ * @since 1.6.0
+ */
+ def DOUBLE: Encoder[java.lang.Double] = ExpressionEncoder()
+
+ /**
+ * An encoder for nullable string type.
+ *
+ * @since 1.6.0
+ */
+ def STRING: Encoder[java.lang.String] = ExpressionEncoder()
+
+ /**
+ * An encoder for nullable decimal type.
+ *
+ * @since 1.6.0
+ */
+ def DECIMAL: Encoder[java.math.BigDecimal] = ExpressionEncoder()
+
+ /**
+ * An encoder for nullable date type.
+ *
+ * @since 1.6.0
+ */
+ def DATE: Encoder[java.sql.Date] = ExpressionEncoder()
+
+ /**
+ * An encoder for nullable timestamp type.
+ *
+ * @since 1.6.0
+ */
+ def TIMESTAMP: Encoder[java.sql.Timestamp] = ExpressionEncoder()
+
+ /**
+ * An encoder for arrays of bytes.
+ *
+ * @since 1.6.1
+ */
+ def BINARY: Encoder[Array[Byte]] = ExpressionEncoder()
+
+ /**
+ * Creates an encoder for Java Bean of type T.
+ *
+ * T must be publicly accessible.
+ *
+ * supported types for java bean field:
+ * - primitive types: boolean, int, double, etc.
+ * - boxed types: Boolean, Integer, Double, etc.
+ * - String
+ * - java.math.BigDecimal
+ * - time related: java.sql.Date, java.sql.Timestamp
+ * - collection types: only array and java.util.List currently, map support is in progress
+ * - nested java bean.
+ *
+ * @since 1.6.0
+ */
+ def bean[T](beanClass: Class[T]): Encoder[T] = ExpressionEncoder.javaBean(beanClass)
+
+ /**
+ * (Scala-specific) Creates an encoder that serializes objects of type T using Kryo.
+ * This encoder maps T into a single byte array (binary) field.
+ *
+ * T must be publicly accessible.
+ *
+ * @since 1.6.0
+ */
+ def kryo[T: ClassTag]: Encoder[T] = genericSerializer(useKryo = true)
+
+ /**
+ * Creates an encoder that serializes objects of type T using Kryo.
+ * This encoder maps T into a single byte array (binary) field.
+ *
+ * T must be publicly accessible.
+ *
+ * @since 1.6.0
+ */
+ def kryo[T](clazz: Class[T]): Encoder[T] = kryo(ClassTag[T](clazz))
+
+ /**
+ * (Scala-specific) Creates an encoder that serializes objects of type T using generic Java
+ * serialization. This encoder maps T into a single byte array (binary) field.
+ *
+ * Note that this is extremely inefficient and should only be used as the last resort.
+ *
+ * T must be publicly accessible.
+ *
+ * @since 1.6.0
+ */
+ def javaSerialization[T: ClassTag]: Encoder[T] = genericSerializer(useKryo = false)
+
+ /**
+ * Creates an encoder that serializes objects of type T using generic Java serialization.
+ * This encoder maps T into a single byte array (binary) field.
+ *
+ * Note that this is extremely inefficient and should only be used as the last resort.
+ *
+ * T must be publicly accessible.
+ *
+ * @since 1.6.0
+ */
+ def javaSerialization[T](clazz: Class[T]): Encoder[T] = javaSerialization(ClassTag[T](clazz))
+
+ /** Throws an exception if T is not a public class. */
+ private def validatePublicClass[T: ClassTag](): Unit = {
+ if (!Modifier.isPublic(classTag[T].runtimeClass.getModifiers)) {
+ throw new UnsupportedOperationException(
+ s"${classTag[T].runtimeClass.getName} is not a public class. " +
+ "Only public classes are supported.")
+ }
+ }
+
+ /** A way to construct encoders using generic serializers. */
+ private def genericSerializer[T: ClassTag](useKryo: Boolean): Encoder[T] = {
+ if (classTag[T].runtimeClass.isPrimitive) {
+ throw new UnsupportedOperationException("Primitive types are not supported.")
+ }
+
+ validatePublicClass[T]()
+
+ ExpressionEncoder[T](
+ schema = new StructType().add("value", BinaryType),
+ flat = true,
+ serializer = Seq(
+ EncodeUsingSerializer(
+ BoundReference(0, ObjectType(classOf[AnyRef]), nullable = true), kryo = useKryo)),
+ deserializer =
+ DecodeUsingSerializer[T](
+ BoundReference(0, BinaryType, nullable = true), classTag[T], kryo = useKryo),
+ clsTag = classTag[T]
+ )
+ }
+
+ /**
+ * An encoder for 2-ary tuples.
+ *
+ * @since 1.6.0
+ */
+ def tuple[T1, T2](
+ e1: Encoder[T1],
+ e2: Encoder[T2]): Encoder[(T1, T2)] = {
+ ExpressionEncoder.tuple(encoderFor(e1), encoderFor(e2))
+ }
+
+ /**
+ * An encoder for 3-ary tuples.
+ *
+ * @since 1.6.0
+ */
+ def tuple[T1, T2, T3](
+ e1: Encoder[T1],
+ e2: Encoder[T2],
+ e3: Encoder[T3]): Encoder[(T1, T2, T3)] = {
+ ExpressionEncoder.tuple(encoderFor(e1), encoderFor(e2), encoderFor(e3))
+ }
+
+ /**
+ * An encoder for 4-ary tuples.
+ *
+ * @since 1.6.0
+ */
+ def tuple[T1, T2, T3, T4](
+ e1: Encoder[T1],
+ e2: Encoder[T2],
+ e3: Encoder[T3],
+ e4: Encoder[T4]): Encoder[(T1, T2, T3, T4)] = {
+ ExpressionEncoder.tuple(encoderFor(e1), encoderFor(e2), encoderFor(e3), encoderFor(e4))
+ }
+
+ /**
+ * An encoder for 5-ary tuples.
+ *
+ * @since 1.6.0
+ */
+ def tuple[T1, T2, T3, T4, T5](
+ e1: Encoder[T1],
+ e2: Encoder[T2],
+ e3: Encoder[T3],
+ e4: Encoder[T4],
+ e5: Encoder[T5]): Encoder[(T1, T2, T3, T4, T5)] = {
+ ExpressionEncoder.tuple(
+ encoderFor(e1), encoderFor(e2), encoderFor(e3), encoderFor(e4), encoderFor(e5))
+ }
+
+ /**
+ * An encoder for Scala's product type (tuples, case classes, etc).
+ * @since 2.0.0
+ */
+ def product[T <: Product : TypeTag]: Encoder[T] = ExpressionEncoder()
+
+ /**
+ * An encoder for Scala's primitive int type.
+ * @since 2.0.0
+ */
+ def scalaInt: Encoder[Int] = ExpressionEncoder()
+
+ /**
+ * An encoder for Scala's primitive long type.
+ * @since 2.0.0
+ */
+ def scalaLong: Encoder[Long] = ExpressionEncoder()
+
+ /**
+ * An encoder for Scala's primitive double type.
+ * @since 2.0.0
+ */
+ def scalaDouble: Encoder[Double] = ExpressionEncoder()
+
+ /**
+ * An encoder for Scala's primitive float type.
+ * @since 2.0.0
+ */
+ def scalaFloat: Encoder[Float] = ExpressionEncoder()
+
+ /**
+ * An encoder for Scala's primitive byte type.
+ * @since 2.0.0
+ */
+ def scalaByte: Encoder[Byte] = ExpressionEncoder()
+
+ /**
+ * An encoder for Scala's primitive short type.
+ * @since 2.0.0
+ */
+ def scalaShort: Encoder[Short] = ExpressionEncoder()
+
+ /**
+ * An encoder for Scala's primitive boolean type.
+ * @since 2.0.0
+ */
+ def scalaBoolean: Encoder[Boolean] = ExpressionEncoder()
+
+}
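
For reference, a small self-contained usage sketch of the factory methods defined above; the Point case class is just an illustrative type, and only this catalyst module on the classpath is assumed:

    import org.apache.spark.sql.Encoders

    // Example type for the product and kryo encoders below (not part of Spark).
    case class Point(x: Double, y: Double)

    object EncodersSketch {
      def main(args: Array[String]): Unit = {
        val boxedInt  = Encoders.INT                                    // java.lang.Integer
        val primInt   = Encoders.scalaInt                               // Scala primitive Int
        val pair      = Encoders.tuple(Encoders.STRING, Encoders.scalaLong)
        val asProduct = Encoders.product[Point]                         // case-class encoder
        val viaKryo   = Encoders.kryo[Point]                            // single binary column

        // Every Encoder carries a ClassTag for its type (see the Encoder trait).
        println(pair.clsTag)
        println(asProduct.clsTag)
        println(viaKryo.clsTag)
      }
    }

The scala* variants and product are the Scala-side counterparts of the upper-case, Java-facing methods, as the doc comments above note.
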
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala
index d5ac01500b..2b98aacdd7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala
@@ -26,7 +26,7 @@ private[spark] trait CatalystConf {
def groupByOrdinal: Boolean
/**
- * Returns the [[Resolver]] for the current configuration, which can be used to determin if two
+ * Returns the [[Resolver]] for the current configuration, which can be used to determine if two
* identifiers are equal.
*/
def resolver: Resolver = {
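
The doc fixed in the hunk above concerns identifier resolution; as a minimal sketch, assuming Resolver is simply Catalyst's alias for a (String, String) => Boolean comparison, the two behaviours a configuration can return boil down to:

    object ResolverSketch {
      // Case-sensitive vs. case-insensitive identifier resolution as plain functions.
      val caseSensitive: (String, String) => Boolean = (a, b) => a == b
      val caseInsensitive: (String, String) => Boolean = (a, b) => a.equalsIgnoreCase(b)

      def main(args: Array[String]): Unit = {
        println(caseInsensitive("colA", "COLA")) // true
        println(caseSensitive("colA", "COLA"))   // false
      }
    }
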
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
index 59ee41d02f..6f9fbbbead 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
@@ -155,16 +155,16 @@ object JavaTypeInference {
}
/**
- * Returns an expression that can be used to construct an object of java bean `T` given an input
- * row with a compatible schema. Fields of the row will be extracted using UnresolvedAttributes
+ * Returns an expression that can be used to deserialize an internal row to an object of java bean
+ * `T` with a compatible schema. Fields of the row will be extracted using UnresolvedAttributes
* of the same name as the constructor arguments. Nested classes will have their fields accessed
* using UnresolvedExtractValue.
*/
- def constructorFor(beanClass: Class[_]): Expression = {
- constructorFor(TypeToken.of(beanClass), None)
+ def deserializerFor(beanClass: Class[_]): Expression = {
+ deserializerFor(TypeToken.of(beanClass), None)
}
- private def constructorFor(typeToken: TypeToken[_], path: Option[Expression]): Expression = {
+ private def deserializerFor(typeToken: TypeToken[_], path: Option[Expression]): Expression = {
/** Returns the current path with a sub-field extracted. */
def addToPath(part: String): Expression = path
.map(p => UnresolvedExtractValue(p, expressions.Literal(part)))
@@ -231,7 +231,7 @@ object JavaTypeInference {
}.getOrElse {
Invoke(
MapObjects(
- p => constructorFor(typeToken.getComponentType, Some(p)),
+ p => deserializerFor(typeToken.getComponentType, Some(p)),
getPath,
inferDataType(elementType)._1),
"array",
@@ -243,7 +243,7 @@ object JavaTypeInference {
val array =
Invoke(
MapObjects(
- p => constructorFor(et, Some(p)),
+ p => deserializerFor(et, Some(p)),
getPath,
inferDataType(et)._1),
"array",
@@ -259,7 +259,7 @@ object JavaTypeInference {
val keyData =
Invoke(
MapObjects(
- p => constructorFor(keyType, Some(p)),
+ p => deserializerFor(keyType, Some(p)),
Invoke(getPath, "keyArray", ArrayType(keyDataType)),
keyDataType),
"array",
@@ -268,7 +268,7 @@ object JavaTypeInference {
val valueData =
Invoke(
MapObjects(
- p => constructorFor(valueType, Some(p)),
+ p => deserializerFor(valueType, Some(p)),
Invoke(getPath, "valueArray", ArrayType(valueDataType)),
valueDataType),
"array",
@@ -288,7 +288,7 @@ object JavaTypeInference {
val fieldName = p.getName
val fieldType = typeToken.method(p.getReadMethod).getReturnType
val (_, nullable) = inferDataType(fieldType)
- val constructor = constructorFor(fieldType, Some(addToPath(fieldName)))
+ val constructor = deserializerFor(fieldType, Some(addToPath(fieldName)))
val setter = if (nullable) {
constructor
} else {
@@ -313,14 +313,14 @@ object JavaTypeInference {
}
/**
- * Returns expressions for extracting all the fields from the given type.
+ * Returns an expression for serializing an object of the given type to an internal row.
*/
- def extractorsFor(beanClass: Class[_]): CreateNamedStruct = {
+ def serializerFor(beanClass: Class[_]): CreateNamedStruct = {
val inputObject = BoundReference(0, ObjectType(beanClass), nullable = true)
- extractorFor(inputObject, TypeToken.of(beanClass)).asInstanceOf[CreateNamedStruct]
+ serializerFor(inputObject, TypeToken.of(beanClass)).asInstanceOf[CreateNamedStruct]
}
- private def extractorFor(inputObject: Expression, typeToken: TypeToken[_]): Expression = {
+ private def serializerFor(inputObject: Expression, typeToken: TypeToken[_]): Expression = {
def toCatalystArray(input: Expression, elementType: TypeToken[_]): Expression = {
val (dataType, nullable) = inferDataType(elementType)
@@ -330,7 +330,7 @@ object JavaTypeInference {
input :: Nil,
dataType = ArrayType(dataType, nullable))
} else {
- MapObjects(extractorFor(_, elementType), input, ObjectType(elementType.getRawType))
+ MapObjects(serializerFor(_, elementType), input, ObjectType(elementType.getRawType))
}
}
@@ -403,7 +403,7 @@ object JavaTypeInference {
inputObject,
p.getReadMethod.getName,
inferExternalType(fieldType.getRawType))
- expressions.Literal(fieldName) :: extractorFor(fieldValue, fieldType) :: Nil
+ expressions.Literal(fieldName) :: serializerFor(fieldValue, fieldType) :: Nil
})
} else {
throw new UnsupportedOperationException(
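
For context on where the renamed JavaTypeInference entry points surface, a bean encoder exercises both of them. A minimal sketch using the public Encoders.bean API follows; `PersonBean` is a hypothetical example, and the printed schema assumes the usual alphabetical ordering of bean properties.

    import org.apache.spark.sql.Encoders

    // A hypothetical Java-style bean; JavaTypeInference reads its getter/setter pairs.
    class PersonBean extends java.io.Serializable {
      private var name: String = _
      private var age: Int = _
      def getName: String = name
      def setName(v: String): Unit = { name = v }
      def getAge: Int = age
      def setAge(v: Int): Unit = { age = v }
    }

    // serializerFor(beanClass) builds the object-to-row expression behind this encoder,
    // deserializerFor(beanClass) builds the row-to-object side.
    val enc = Encoders.bean(classOf[PersonBean])
    println(enc.schema.simpleString)   // struct<age:int,name:string>
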
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index f208401160..4795fc2557 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -110,8 +110,8 @@ object ScalaReflection extends ScalaReflection {
}
/**
- * Returns an expression that can be used to construct an object of type `T` given an input
- * row with a compatible schema. Fields of the row will be extracted using UnresolvedAttributes
+ * Returns an expression that can be used to deserialize an input row to an object of type `T`
+ * with a compatible schema. Fields of the row will be extracted using UnresolvedAttributes
* of the same name as the constructor arguments. Nested classes will have their fields accessed
* using UnresolvedExtractValue.
*
@@ -119,14 +119,14 @@ object ScalaReflection extends ScalaReflection {
* from ordinal 0 (since there are no names to map to). The actual location can be moved by
* calling resolve/bind with a new schema.
*/
- def constructorFor[T : TypeTag]: Expression = {
+ def deserializerFor[T : TypeTag]: Expression = {
val tpe = localTypeOf[T]
val clsName = getClassNameFromType(tpe)
val walkedTypePath = s"""- root class: "${clsName}"""" :: Nil
- constructorFor(tpe, None, walkedTypePath)
+ deserializerFor(tpe, None, walkedTypePath)
}
- private def constructorFor(
+ private def deserializerFor(
tpe: `Type`,
path: Option[Expression],
walkedTypePath: Seq[String]): Expression = ScalaReflectionLock.synchronized {
@@ -161,7 +161,7 @@ object ScalaReflection extends ScalaReflection {
}
/**
- * When we build the `fromRowExpression` for an encoder, we set up a lot of "unresolved" stuff
+ * When we build the `deserializer` for an encoder, we set up a lot of "unresolved" stuff
 * and lose the required data type, which may lead to a runtime error if the real type doesn't
* match the encoder's schema.
* For example, we build an encoder for `case class Data(a: Int, b: String)` and the real type
@@ -188,7 +188,7 @@ object ScalaReflection extends ScalaReflection {
val TypeRef(_, _, Seq(optType)) = t
val className = getClassNameFromType(optType)
val newTypePath = s"""- option value class: "$className"""" +: walkedTypePath
- WrapOption(constructorFor(optType, path, newTypePath), dataTypeFor(optType))
+ WrapOption(deserializerFor(optType, path, newTypePath), dataTypeFor(optType))
case t if t <:< localTypeOf[java.lang.Integer] =>
val boxedType = classOf[java.lang.Integer]
@@ -272,7 +272,7 @@ object ScalaReflection extends ScalaReflection {
val newTypePath = s"""- array element class: "$className"""" +: walkedTypePath
Invoke(
MapObjects(
- p => constructorFor(elementType, Some(p), newTypePath),
+ p => deserializerFor(elementType, Some(p), newTypePath),
getPath,
schemaFor(elementType).dataType),
"array",
@@ -286,7 +286,7 @@ object ScalaReflection extends ScalaReflection {
val newTypePath = s"""- array element class: "$className"""" +: walkedTypePath
val mapFunction: Expression => Expression = p => {
- val converter = constructorFor(elementType, Some(p), newTypePath)
+ val converter = deserializerFor(elementType, Some(p), newTypePath)
if (nullable) {
converter
} else {
@@ -312,7 +312,7 @@ object ScalaReflection extends ScalaReflection {
val keyData =
Invoke(
MapObjects(
- p => constructorFor(keyType, Some(p), walkedTypePath),
+ p => deserializerFor(keyType, Some(p), walkedTypePath),
Invoke(getPath, "keyArray", ArrayType(schemaFor(keyType).dataType)),
schemaFor(keyType).dataType),
"array",
@@ -321,7 +321,7 @@ object ScalaReflection extends ScalaReflection {
val valueData =
Invoke(
MapObjects(
- p => constructorFor(valueType, Some(p), walkedTypePath),
+ p => deserializerFor(valueType, Some(p), walkedTypePath),
Invoke(getPath, "valueArray", ArrayType(schemaFor(valueType).dataType)),
schemaFor(valueType).dataType),
"array",
@@ -344,12 +344,12 @@ object ScalaReflection extends ScalaReflection {
val newTypePath = s"""- field (class: "$clsName", name: "$fieldName")""" +: walkedTypePath
// For tuples, we grab the inner fields by ordinal instead of name.
if (cls.getName startsWith "scala.Tuple") {
- constructorFor(
+ deserializerFor(
fieldType,
Some(addToPathOrdinal(i, dataType, newTypePath)),
newTypePath)
} else {
- val constructor = constructorFor(
+ val constructor = deserializerFor(
fieldType,
Some(addToPath(fieldName, dataType, newTypePath)),
newTypePath)
@@ -387,7 +387,7 @@ object ScalaReflection extends ScalaReflection {
}
/**
- * Returns expressions for extracting all the fields from the given type.
+ * Returns an expression for serializing an object of type T to an internal row.
*
 * If the given type is not supported, i.e. no encoder can be built for this type,
* an [[UnsupportedOperationException]] will be thrown with detailed error message to explain
@@ -398,18 +398,18 @@ object ScalaReflection extends ScalaReflection {
* * the element type of [[Array]] or [[Seq]]: `array element class: "abc.xyz.MyClass"`
* * the field of [[Product]]: `field (class: "abc.xyz.MyClass", name: "myField")`
*/
- def extractorsFor[T : TypeTag](inputObject: Expression): CreateNamedStruct = {
+ def serializerFor[T : TypeTag](inputObject: Expression): CreateNamedStruct = {
val tpe = localTypeOf[T]
val clsName = getClassNameFromType(tpe)
val walkedTypePath = s"""- root class: "${clsName}"""" :: Nil
- extractorFor(inputObject, tpe, walkedTypePath) match {
+ serializerFor(inputObject, tpe, walkedTypePath) match {
case expressions.If(_, _, s: CreateNamedStruct) if tpe <:< localTypeOf[Product] => s
case other => CreateNamedStruct(expressions.Literal("value") :: other :: Nil)
}
}
/** Helper for extracting internal fields from a case class. */
- private def extractorFor(
+ private def serializerFor(
inputObject: Expression,
tpe: `Type`,
walkedTypePath: Seq[String]): Expression = ScalaReflectionLock.synchronized {
@@ -425,7 +425,7 @@ object ScalaReflection extends ScalaReflection {
} else {
val clsName = getClassNameFromType(elementType)
val newPath = s"""- array element class: "$clsName"""" +: walkedTypePath
- MapObjects(extractorFor(_, elementType, newPath), input, externalDataType)
+ MapObjects(serializerFor(_, elementType, newPath), input, externalDataType)
}
}
@@ -491,7 +491,7 @@ object ScalaReflection extends ScalaReflection {
expressions.If(
IsNull(unwrapped),
expressions.Literal.create(null, silentSchemaFor(optType).dataType),
- extractorFor(unwrapped, optType, newPath))
+ serializerFor(unwrapped, optType, newPath))
}
case t if t <:< localTypeOf[Product] =>
@@ -500,7 +500,7 @@ object ScalaReflection extends ScalaReflection {
val fieldValue = Invoke(inputObject, fieldName, dataTypeFor(fieldType))
val clsName = getClassNameFromType(fieldType)
val newPath = s"""- field (class: "$clsName", name: "$fieldName")""" +: walkedTypePath
- expressions.Literal(fieldName) :: extractorFor(fieldValue, fieldType, newPath) :: Nil
+ expressions.Literal(fieldName) :: serializerFor(fieldValue, fieldType, newPath) :: Nil
})
val nullOutput = expressions.Literal.create(null, nonNullOutput.dataType)
expressions.If(IsNull(inputObject), nullOutput, nonNullOutput)
@@ -762,15 +762,15 @@ trait ScalaReflection {
}
/**
- * Returns the full class name for a type. The returned name is the canonical
- * Scala name, where each component is separated by a period. It is NOT the
- * Java-equivalent runtime name (no dollar signs).
- *
- * In simple cases, both the Scala and Java names are the same, however when Scala
- * generates constructs that do not map to a Java equivalent, such as singleton objects
- * or nested classes in package objects, it uses the dollar sign ($) to create
- * synthetic classes, emulating behaviour in Java bytecode.
- */
+ * Returns the full class name for a type. The returned name is the canonical
+ * Scala name, where each component is separated by a period. It is NOT the
+ * Java-equivalent runtime name (no dollar signs).
+ *
+ * In simple cases, both the Scala and Java names are the same, however when Scala
+ * generates constructs that do not map to a Java equivalent, such as singleton objects
+ * or nested classes in package objects, it uses the dollar sign ($) to create
+ * synthetic classes, emulating behaviour in Java bytecode.
+ */
def getClassNameFromType(tpe: `Type`): String = {
tpe.erasure.typeSymbol.asClass.fullName
}
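
The renamed ScalaReflection entry points are what ExpressionEncoder wires together. Roughly, and as a sketch only (these are catalyst-internal APIs, so exact signatures may shift between builds; `Point` is a made-up case class):

    import org.apache.spark.sql.catalyst.ScalaReflection
    import org.apache.spark.sql.catalyst.expressions.BoundReference

    // Hypothetical case class used only for illustration.
    case class Point(x: Int, y: Double)

    // Object -> internal row: the input object is addressed as ordinal 0,
    // the same way ExpressionEncoder sets it up before binding.
    val inputObject = BoundReference(0, ScalaReflection.dataTypeFor[Point], nullable = true)
    val serializer = ScalaReflection.serializerFor[Point](inputObject)

    // Internal row -> object: still unresolved at this point; the analyzer's
    // ResolveDeserializer and ResolveNewInstance rules (see Analyzer.scala below)
    // finish the binding.
    val deserializer = ScalaReflection.deserializerFor[Point]

    println(serializer.dataType.simpleString)   // struct<x:int,y:double>
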
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 3b83e68018..de40ddde1b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -17,8 +17,6 @@
package org.apache.spark.sql.catalyst.analysis
-import java.lang.reflect.Modifier
-
import scala.annotation.tailrec
import scala.collection.mutable.ArrayBuffer
@@ -42,9 +40,12 @@ import org.apache.spark.sql.types._
* to resolve attribute references.
*/
object SimpleAnalyzer
- extends SimpleAnalyzer(new SimpleCatalystConf(caseSensitiveAnalysis = true))
-class SimpleAnalyzer(conf: CatalystConf)
- extends Analyzer(new SessionCatalog(new InMemoryCatalog, conf), EmptyFunctionRegistry, conf)
+ extends SimpleAnalyzer(
+ EmptyFunctionRegistry,
+ new SimpleCatalystConf(caseSensitiveAnalysis = true))
+
+class SimpleAnalyzer(functionRegistry: FunctionRegistry, conf: CatalystConf)
+ extends Analyzer(new SessionCatalog(new InMemoryCatalog, functionRegistry, conf), conf)
/**
* Provides a logical query plan analyzer, which translates [[UnresolvedAttribute]]s and
@@ -53,7 +54,6 @@ class SimpleAnalyzer(conf: CatalystConf)
*/
class Analyzer(
catalog: SessionCatalog,
- registry: FunctionRegistry,
conf: CatalystConf,
maxIterations: Int = 100)
extends RuleExecutor[LogicalPlan] with CheckAnalysis {
@@ -81,11 +81,13 @@ class Analyzer(
Batch("Resolution", fixedPoint,
ResolveRelations ::
ResolveReferences ::
+ ResolveDeserializer ::
+ ResolveNewInstance ::
+ ResolveUpCast ::
ResolveGroupingAnalytics ::
ResolvePivot ::
- ResolveUpCast ::
ResolveOrdinalInOrderByAndGroupBy ::
- ResolveSortReferences ::
+ ResolveMissingReferences ::
ResolveGenerate ::
ResolveFunctions ::
ResolveAliases ::
@@ -96,6 +98,7 @@ class Analyzer(
ExtractWindowExpressions ::
GlobalAggregates ::
ResolveAggregateFunctions ::
+ TimeWindowing ::
HiveTypeCoercion.typeCoercionRules ++
extendedResolutionRules : _*),
Batch("Nondeterministic", Once,
@@ -225,21 +228,56 @@ class Analyzer(
Seq.tabulate(1 << c.groupByExprs.length)(i => i)
}
- private def hasGroupingId(expr: Seq[Expression]): Boolean = {
- expr.exists(_.collectFirst {
- case u: UnresolvedAttribute if resolver(u.name, VirtualColumn.groupingIdName) => u
- }.isDefined)
+ private def hasGroupingAttribute(expr: Expression): Boolean = {
+ expr.collectFirst {
+ case u: UnresolvedAttribute if resolver(u.name, VirtualColumn.hiveGroupingIdName) => u
+ }.isDefined
}
- def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
+ private def hasGroupingFunction(e: Expression): Boolean = {
+ e.collectFirst {
+ case g: Grouping => g
+ case g: GroupingID => g
+ }.isDefined
+ }
+
+ private def replaceGroupingFunc(
+ expr: Expression,
+ groupByExprs: Seq[Expression],
+ gid: Expression): Expression = {
+ expr transform {
+ case e: GroupingID =>
+ if (e.groupByExprs.isEmpty || e.groupByExprs == groupByExprs) {
+ gid
+ } else {
+ throw new AnalysisException(
+ s"Columns of grouping_id (${e.groupByExprs.mkString(",")}) does not match " +
+ s"grouping columns (${groupByExprs.mkString(",")})")
+ }
+ case Grouping(col: Expression) =>
+ val idx = groupByExprs.indexOf(col)
+ if (idx >= 0) {
+ Cast(BitwiseAnd(ShiftRight(gid, Literal(groupByExprs.length - 1 - idx)),
+ Literal(1)), ByteType)
+ } else {
+ throw new AnalysisException(s"Column of grouping ($col) can't be found " +
+ s"in grouping columns ${groupByExprs.mkString(",")}")
+ }
+ }
+ }
+
+ // This requires transformUp to replace grouping()/grouping_id() in resolved Filter/Sort
+ def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
case a if !a.childrenResolved => a // be sure all of the children are resolved.
+ case p if p.expressions.exists(hasGroupingAttribute) =>
+ failAnalysis(
+ s"${VirtualColumn.hiveGroupingIdName} is deprecated; use grouping_id() instead")
+
case Aggregate(Seq(c @ Cube(groupByExprs)), aggregateExpressions, child) =>
GroupingSets(bitmasks(c), groupByExprs, child, aggregateExpressions)
case Aggregate(Seq(r @ Rollup(groupByExprs)), aggregateExpressions, child) =>
GroupingSets(bitmasks(r), groupByExprs, child, aggregateExpressions)
- case g: GroupingSets if g.expressions.exists(!_.resolved) && hasGroupingId(g.expressions) =>
- failAnalysis(
- s"${VirtualColumn.groupingIdName} is deprecated; use grouping_id() instead")
+
// Ensure all the expressions have been resolved.
case x: GroupingSets if x.expressions.forall(_.resolved) =>
val gid = AttributeReference(VirtualColumn.groupingIdName, IntegerType, false)()
@@ -267,7 +305,7 @@ class Analyzer(
def isPartOfAggregation(e: Expression): Boolean = {
aggsBuffer.exists(a => a.find(_ eq e).isDefined)
}
- expr.transformDown {
+ replaceGroupingFunc(expr, x.groupByExprs, gid).transformDown {
// AggregateExpression should be computed on the unmodified value of its argument
// expressions, so we should not replace any references to grouping expression
// inside it.
@@ -275,23 +313,6 @@ class Analyzer(
aggsBuffer += e
e
case e if isPartOfAggregation(e) => e
- case e: GroupingID =>
- if (e.groupByExprs.isEmpty || e.groupByExprs == x.groupByExprs) {
- gid
- } else {
- throw new AnalysisException(
- s"Columns of grouping_id (${e.groupByExprs.mkString(",")}) does not match " +
- s"grouping columns (${x.groupByExprs.mkString(",")})")
- }
- case Grouping(col: Expression) =>
- val idx = x.groupByExprs.indexOf(col)
- if (idx >= 0) {
- Cast(BitwiseAnd(ShiftRight(gid, Literal(x.groupByExprs.length - 1 - idx)),
- Literal(1)), ByteType)
- } else {
- throw new AnalysisException(s"Column of grouping ($col) can't be found " +
- s"in grouping columns ${x.groupByExprs.mkString(",")}")
- }
case e =>
val index = groupByAliases.indexWhere(_.child.semanticEquals(e))
if (index == -1) {
@@ -303,9 +324,37 @@ class Analyzer(
}
Aggregate(
- groupByAttributes :+ VirtualColumn.groupingIdAttribute,
+ groupByAttributes :+ gid,
aggregations,
Expand(x.bitmasks, groupByAliases, groupByAttributes, gid, x.child))
+
+ case f @ Filter(cond, child) if hasGroupingFunction(cond) =>
+ val groupingExprs = findGroupingExprs(child)
+ // The unresolved grouping id will be resolved by ResolveMissingReferences
+ val newCond = replaceGroupingFunc(cond, groupingExprs, VirtualColumn.groupingIdAttribute)
+ f.copy(condition = newCond)
+
+ case s @ Sort(order, _, child) if order.exists(hasGroupingFunction) =>
+ val groupingExprs = findGroupingExprs(child)
+ val gid = VirtualColumn.groupingIdAttribute
+ // The unresolved grouping id will be resolved by ResolveMissingReferences
+ val newOrder = order.map(replaceGroupingFunc(_, groupingExprs, gid).asInstanceOf[SortOrder])
+ s.copy(order = newOrder)
+ }
+
+ private def findGroupingExprs(plan: LogicalPlan): Seq[Expression] = {
+ plan.collectFirst {
+ case a: Aggregate =>
+ // this Aggregate should have grouping id as the last grouping key.
+ val gid = a.groupingExpressions.last
+ if (!gid.isInstanceOf[AttributeReference]
+ || gid.asInstanceOf[AttributeReference].name != VirtualColumn.groupingIdName) {
+ failAnalysis(s"grouping()/grouping_id() can only be used with GroupingSets/Cube/Rollup")
+ }
+ a.groupingExpressions.take(a.groupingExpressions.length - 1)
+ }.getOrElse {
+ failAnalysis(s"grouping()/grouping_id() can only be used with GroupingSets/Cube/Rollup")
+ }
}
}
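
The bit arithmetic that replaceGroupingFunc generates for grouping(col) can be checked in isolation. The following is a small standalone sketch in plain Scala mirroring the ShiftRight/BitwiseAnd/Cast chain above; the column positions and gid value are made-up examples.

    // grouping(col) is bit (n - 1 - idx) of the grouping id, where n is the number of
    // GROUP BY columns and idx is col's position: 1 means the column is aggregated away.
    def grouping(gid: Int, numGroupBy: Int, idx: Int): Byte =
      ((gid >> (numGroupBy - 1 - idx)) & 1).toByte

    // GROUP BY a, b, c with the grouping set (a): the bits for b and c are set, so gid = 0b011.
    val gid = Integer.parseInt("011", 2)
    println(Seq(0, 1, 2).map(grouping(gid, 3, _)))   // List(0, 1, 1) = grouping(a), grouping(b), grouping(c)
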
@@ -329,6 +378,11 @@ class Analyzer(
Last(ifExpr(expr), Literal(true))
case a: AggregateFunction =>
a.withNewChildren(a.children.map(ifExpr))
+ }.transform {
+ // We are duplicating aggregates that are now computing a different value for each
+ // pivot value.
+ // TODO: Don't construct the physical container until after analysis.
+ case ae: AggregateExpression => ae.copy(resultId = NamedExpression.newExprId)
}
if (filteredAggregate.fastEquals(aggregate)) {
throw new AnalysisException(
@@ -355,7 +409,7 @@ class Analyzer(
catalog.lookupRelation(u.tableIdentifier, u.alias)
} catch {
case _: NoSuchTableException =>
- u.failAnalysis(s"Table not found: ${u.tableName}")
+ u.failAnalysis(s"Table or View not found: ${u.tableName}")
}
}
@@ -487,18 +541,9 @@ class Analyzer(
Generate(newG.asInstanceOf[Generator], join, outer, qualifier, output, child)
}
- // A special case for ObjectOperator, because the deserializer expressions in ObjectOperator
- // should be resolved by their corresponding attributes instead of children's output.
- case o: ObjectOperator if containsUnresolvedDeserializer(o.deserializers.map(_._1)) =>
- val deserializerToAttributes = o.deserializers.map {
- case (deserializer, attributes) => new TreeNodeRef(deserializer) -> attributes
- }.toMap
-
- o.transformExpressions {
- case expr => deserializerToAttributes.get(new TreeNodeRef(expr)).map { attributes =>
- resolveDeserializer(expr, attributes)
- }.getOrElse(expr)
- }
+ // Skips plan which contains deserializer expressions, as they should be resolved by another
+ // rule: ResolveDeserializer.
+ case plan if containsDeserializer(plan.expressions) => plan
case q: LogicalPlan =>
logTrace(s"Attempting to resolve ${q.simpleString}")
@@ -514,38 +559,6 @@ class Analyzer(
}
}
- private def containsUnresolvedDeserializer(exprs: Seq[Expression]): Boolean = {
- exprs.exists { expr =>
- !expr.resolved || expr.find(_.isInstanceOf[BoundReference]).isDefined
- }
- }
-
- def resolveDeserializer(
- deserializer: Expression,
- attributes: Seq[Attribute]): Expression = {
- val unbound = deserializer transform {
- case b: BoundReference => attributes(b.ordinal)
- }
-
- resolveExpression(unbound, LocalRelation(attributes), throws = true) transform {
- case n: NewInstance
- // If this is an inner class of another class, register the outer object in `OuterScopes`.
- // Note that static inner classes (e.g., inner classes within Scala objects) don't need
- // outer pointer registration.
- if n.outerPointer.isEmpty &&
- n.cls.isMemberClass &&
- !Modifier.isStatic(n.cls.getModifiers) =>
- val outer = OuterScopes.getOuterScope(n.cls)
- if (outer == null) {
- throw new AnalysisException(
- s"Unable to generate an encoder for inner class `${n.cls.getName}` without " +
- "access to the scope that this class was defined in.\n" +
- "Try moving this class out of its parent class.")
- }
- n.copy(outerPointer = Some(outer))
- }
- }
-
def newAliases(expressions: Seq[NamedExpression]): Seq[NamedExpression] = {
expressions.map {
case a: Alias => Alias(a.child, a.name)(isGenerated = a.isGenerated)
@@ -611,6 +624,10 @@ class Analyzer(
}
}
+ private def containsDeserializer(exprs: Seq[Expression]): Boolean = {
+ exprs.exists(_.find(_.isInstanceOf[UnresolvedDeserializer]).isDefined)
+ }
+
protected[sql] def resolveExpression(
expr: Expression,
plan: LogicalPlan,
@@ -692,13 +709,15 @@ class Analyzer(
* clause. This rule detects such queries and adds the required attributes to the original
* projection, so that they will be available during sorting. Another projection is added to
* remove these attributes after sorting.
+ *
+ * The HAVING clause could also use grouping columns that are not present in the SELECT.
*/
- object ResolveSortReferences extends Rule[LogicalPlan] {
+ object ResolveMissingReferences extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
// Skip sort with aggregate. This will be handled in ResolveAggregateFunctions
case sa @ Sort(_, _, child: Aggregate) => sa
- case s @ Sort(order, _, child) if !s.resolved && child.resolved =>
+ case s @ Sort(order, _, child) if child.resolved =>
try {
val newOrder = order.map(resolveExpressionRecursively(_, child).asInstanceOf[SortOrder])
val requiredAttrs = AttributeSet(newOrder).filter(_.resolved)
@@ -718,12 +737,32 @@ class Analyzer(
// in Sort
case ae: AnalysisException => s
}
+
+ case f @ Filter(cond, child) if child.resolved =>
+ try {
+ val newCond = resolveExpressionRecursively(cond, child)
+ val requiredAttrs = newCond.references.filter(_.resolved)
+ val missingAttrs = requiredAttrs -- child.outputSet
+ if (missingAttrs.nonEmpty) {
+ // Add missing attributes and then project them away.
+ Project(child.output,
+ Filter(newCond, addMissingAttr(child, missingAttrs)))
+ } else if (newCond != cond) {
+ f.copy(condition = newCond)
+ } else {
+ f
+ }
+ } catch {
+ // Attempting to resolve it might fail. When this happens, return the original plan.
+ // Users will see an AnalysisException for resolution failure of missing attributes
+ case ae: AnalysisException => f
+ }
}
/**
- * Add the missing attributes into projectList of Project/Window or aggregateExpressions of
- * Aggregate.
- */
+ * Add the missing attributes into projectList of Project/Window or aggregateExpressions of
+ * Aggregate.
+ */
private def addMissingAttr(plan: LogicalPlan, missingAttrs: AttributeSet): LogicalPlan = {
if (missingAttrs.isEmpty) {
return plan
@@ -755,9 +794,9 @@ class Analyzer(
}
/**
- * Resolve the expression on a specified logical plan and it's child (recursively), until
- * the expression is resolved or meet a non-unary node or Subquery.
- */
+ * Resolve the expression on a specified logical plan and its child (recursively), until
+ * the expression is resolved or meets a non-unary node or Subquery.
+ */
@tailrec
private def resolveExpressionRecursively(expr: Expression, plan: LogicalPlan): Expression = {
val resolved = resolveExpression(expr, plan)
@@ -781,9 +820,18 @@ class Analyzer(
case q: LogicalPlan =>
q transformExpressions {
case u if !u.childrenResolved => u // Skip until children are resolved.
+ case u @ UnresolvedGenerator(name, children) =>
+ withPosition(u) {
+ catalog.lookupFunction(name, children) match {
+ case generator: Generator => generator
+ case other =>
+ failAnalysis(s"$name is expected to be a generator. However, " +
+ s"its class is ${other.getClass.getCanonicalName}, which is not a generator.")
+ }
+ }
case u @ UnresolvedFunction(name, children, isDistinct) =>
withPosition(u) {
- registry.lookupFunction(name, children) match {
+ catalog.lookupFunction(name, children) match {
// DISTINCT is not meaningful for a Max or a Min.
case max: Max if isDistinct =>
AggregateExpression(max, Complete, isDistinct = false)
@@ -863,27 +911,33 @@ class Analyzer(
if aggregate.resolved =>
// Try resolving the condition of the filter as though it is in the aggregate clause
- val aggregatedCondition =
- Aggregate(
- grouping,
- Alias(havingCondition, "havingCondition")(isGenerated = true) :: Nil,
- child)
- val resolvedOperator = execute(aggregatedCondition)
- def resolvedAggregateFilter =
- resolvedOperator
- .asInstanceOf[Aggregate]
- .aggregateExpressions.head
-
- // If resolution was successful and we see the filter has an aggregate in it, add it to
- // the original aggregate operator.
- if (resolvedOperator.resolved && containsAggregate(resolvedAggregateFilter)) {
- val aggExprsWithHaving = resolvedAggregateFilter +: originalAggExprs
-
- Project(aggregate.output,
- Filter(resolvedAggregateFilter.toAttribute,
- aggregate.copy(aggregateExpressions = aggExprsWithHaving)))
- } else {
- filter
+ try {
+ val aggregatedCondition =
+ Aggregate(
+ grouping,
+ Alias(havingCondition, "havingCondition")(isGenerated = true) :: Nil,
+ child)
+ val resolvedOperator = execute(aggregatedCondition)
+ def resolvedAggregateFilter =
+ resolvedOperator
+ .asInstanceOf[Aggregate]
+ .aggregateExpressions.head
+
+ // If resolution was successful and we see the filter has an aggregate in it, add it to
+ // the original aggregate operator.
+ if (resolvedOperator.resolved && containsAggregate(resolvedAggregateFilter)) {
+ val aggExprsWithHaving = resolvedAggregateFilter +: originalAggExprs
+
+ Project(aggregate.output,
+ Filter(resolvedAggregateFilter.toAttribute,
+ aggregate.copy(aggregateExpressions = aggExprsWithHaving)))
+ } else {
+ filter
+ }
+ } catch {
+ // Attempting to resolve in the aggregate can result in ambiguity. When this happens,
+ // just return the original plan.
+ case ae: AnalysisException => filter
}
case sort @ Sort(sortOrder, global, aggregate: Aggregate) if aggregate.resolved =>
@@ -947,11 +1001,8 @@ class Analyzer(
}
}
- private def isAggregateExpression(e: Expression): Boolean = {
- e.isInstanceOf[AggregateExpression] || e.isInstanceOf[Grouping] || e.isInstanceOf[GroupingID]
- }
def containsAggregate(condition: Expression): Boolean = {
- condition.find(isAggregateExpression).isDefined
+ condition.find(_.isInstanceOf[AggregateExpression]).isDefined
}
}
@@ -1146,11 +1197,11 @@ class Analyzer(
// Extract Windowed AggregateExpression
case we @ WindowExpression(
- AggregateExpression(function, mode, isDistinct),
+ ae @ AggregateExpression(function, _, _, _),
spec: WindowSpecDefinition) =>
val newChildren = function.children.map(extractExpr)
val newFunction = function.withNewChildren(newChildren).asInstanceOf[AggregateFunction]
- val newAgg = AggregateExpression(newFunction, mode, isDistinct)
+ val newAgg = ae.copy(aggregateFunction = newFunction)
seenWindowAggregates += newAgg
WindowExpression(newAgg, spec)
@@ -1386,8 +1437,8 @@ class Analyzer(
}
/**
- * Check and add order to [[AggregateWindowFunction]]s.
- */
+ * Check and add order to [[AggregateWindowFunction]]s.
+ */
object ResolveWindowOrder extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case logical: LogicalPlan => logical transformExpressions {
@@ -1444,7 +1495,7 @@ class Analyzer(
val projectList = joinType match {
case LeftOuter =>
leftKeys ++ lUniqueOutput ++ rUniqueOutput.map(_.withNullability(true))
- case LeftSemi =>
+ case LeftExistence(_) =>
leftKeys ++ lUniqueOutput
case RightOuter =>
rightKeys ++ lUniqueOutput.map(_.withNullability(true)) ++ rUniqueOutput
@@ -1463,7 +1514,94 @@ class Analyzer(
Project(projectList, Join(left, right, joinType, newCondition))
}
+ /**
+ * Replaces [[UnresolvedDeserializer]] with the deserialization expression that has been resolved
+ * to the given input attributes.
+ */
+ object ResolveDeserializer extends Rule[LogicalPlan] {
+ def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
+ case p if !p.childrenResolved => p
+ case p if p.resolved => p
+ case p => p transformExpressions {
+ case UnresolvedDeserializer(deserializer, inputAttributes) =>
+ val inputs = if (inputAttributes.isEmpty) {
+ p.children.flatMap(_.output)
+ } else {
+ inputAttributes
+ }
+ val unbound = deserializer transform {
+ case b: BoundReference => inputs(b.ordinal)
+ }
+ resolveExpression(unbound, LocalRelation(inputs), throws = true)
+ }
+ }
+ }
+
+ /**
+ * Resolves [[NewInstance]] by finding and adding the outer scope to it if the object being
+ * constructed is an inner class.
+ */
+ object ResolveNewInstance extends Rule[LogicalPlan] {
+ def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
+ case p if !p.childrenResolved => p
+ case p if p.resolved => p
+
+ case p => p transformExpressions {
+ case n: NewInstance if n.childrenResolved && !n.resolved =>
+ val outer = OuterScopes.getOuterScope(n.cls)
+ if (outer == null) {
+ throw new AnalysisException(
+ s"Unable to generate an encoder for inner class `${n.cls.getName}` without " +
+ "access to the scope that this class was defined in.\n" +
+ "Try moving this class out of its parent class.")
+ }
+ n.copy(outerPointer = Some(outer))
+ }
+ }
+ }
+
+ /**
+ * Replace the [[UpCast]] expression by [[Cast]], and throw exceptions if the cast may truncate.
+ */
+ object ResolveUpCast extends Rule[LogicalPlan] {
+ private def fail(from: Expression, to: DataType, walkedTypePath: Seq[String]) = {
+ throw new AnalysisException(s"Cannot up cast ${from.sql} from " +
+ s"${from.dataType.simpleString} to ${to.simpleString} as it may truncate\n" +
+ "The type path of the target object is:\n" + walkedTypePath.mkString("", "\n", "\n") +
+ "You can either add an explicit cast to the input data or choose a higher precision " +
+ "type of the field in the target object")
+ }
+
+ private def illegalNumericPrecedence(from: DataType, to: DataType): Boolean = {
+ val fromPrecedence = HiveTypeCoercion.numericPrecedence.indexOf(from)
+ val toPrecedence = HiveTypeCoercion.numericPrecedence.indexOf(to)
+ toPrecedence > 0 && fromPrecedence > toPrecedence
+ }
+
+ def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
+ case p if !p.childrenResolved => p
+ case p if p.resolved => p
+
+ case p => p transformExpressions {
+ case u @ UpCast(child, _, _) if !child.resolved => u
+
+ case UpCast(child, dataType, walkedTypePath) => (child.dataType, dataType) match {
+ case (from: NumericType, to: DecimalType) if !to.isWiderThan(from) =>
+ fail(child, to, walkedTypePath)
+ case (from: DecimalType, to: NumericType) if !from.isTighterThan(to) =>
+ fail(child, to, walkedTypePath)
+ case (from, to) if illegalNumericPrecedence(from, to) =>
+ fail(child, to, walkedTypePath)
+ case (TimestampType, DateType) =>
+ fail(child, DateType, walkedTypePath)
+ case (StringType, to: NumericType) =>
+ fail(child, to, walkedTypePath)
+ case _ => Cast(child, dataType.asNullable)
+ }
+ }
+ }
+ }
}
/**
@@ -1477,8 +1615,8 @@ object EliminateSubqueryAliases extends Rule[LogicalPlan] {
}
/**
- * Removes [[Union]] operators from the plan if it just has one child.
- */
+ * Removes [[Union]] operators from the plan if it just has one child.
+ */
object EliminateUnions extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case Union(children) if children.size == 1 => children.head
@@ -1532,6 +1670,8 @@ object CleanupAliases extends Rule[LogicalPlan] {
// Operators that operate on objects should only have expressions from encoders, which should
// never have extra aliases.
case o: ObjectOperator => o
+ case d: DeserializeToObject => d
+ case s: SerializeFromObject => s
case other =>
var stop = false
@@ -1548,40 +1688,90 @@ object CleanupAliases extends Rule[LogicalPlan] {
}
/**
- * Replace the `UpCast` expression by `Cast`, and throw exceptions if the cast may truncate.
+ * Maps a time column to multiple time windows using the Expand operator. Since it's non-trivial to
+ * figure out how many windows a time column can map to, we over-estimate the number of windows and
+ * filter out the rows where the time column is not inside the time window.
*/
-object ResolveUpCast extends Rule[LogicalPlan] {
- private def fail(from: Expression, to: DataType, walkedTypePath: Seq[String]) = {
- throw new AnalysisException(s"Cannot up cast ${from.sql} from " +
- s"${from.dataType.simpleString} to ${to.simpleString} as it may truncate\n" +
- "The type path of the target object is:\n" + walkedTypePath.mkString("", "\n", "\n") +
- "You can either add an explicit cast to the input data or choose a higher precision " +
- "type of the field in the target object")
- }
+object TimeWindowing extends Rule[LogicalPlan] {
+ import org.apache.spark.sql.catalyst.dsl.expressions._
- private def illegalNumericPrecedence(from: DataType, to: DataType): Boolean = {
- val fromPrecedence = HiveTypeCoercion.numericPrecedence.indexOf(from)
- val toPrecedence = HiveTypeCoercion.numericPrecedence.indexOf(to)
- toPrecedence > 0 && fromPrecedence > toPrecedence
- }
+ private final val WINDOW_START = "start"
+ private final val WINDOW_END = "end"
- def apply(plan: LogicalPlan): LogicalPlan = {
- plan transformAllExpressions {
- case u @ UpCast(child, _, _) if !child.resolved => u
-
- case UpCast(child, dataType, walkedTypePath) => (child.dataType, dataType) match {
- case (from: NumericType, to: DecimalType) if !to.isWiderThan(from) =>
- fail(child, to, walkedTypePath)
- case (from: DecimalType, to: NumericType) if !from.isTighterThan(to) =>
- fail(child, to, walkedTypePath)
- case (from, to) if illegalNumericPrecedence(from, to) =>
- fail(child, to, walkedTypePath)
- case (TimestampType, DateType) =>
- fail(child, DateType, walkedTypePath)
- case (StringType, to: NumericType) =>
- fail(child, to, walkedTypePath)
- case _ => Cast(child, dataType.asNullable)
+ /**
+ * Generates the logical plan for generating window ranges on a timestamp column. Without
+ * knowing what the timestamp value is, it's non-trivial to figure out deterministically how many
+ * window ranges a timestamp will map to given all possible combinations of a window duration,
+ * slide duration and start time (offset). Therefore, we over-estimate the number of
+ * windows there may be and filter out the invalid ones afterwards. We use the last Project operator to group
+ * the window columns into a struct so they can be accessed as `window.start` and `window.end`.
+ *
+ * The windows are calculated as below:
+ * maxNumOverlapping <- ceil(windowDuration / slideDuration)
+ * for (i <- 0 until maxNumOverlapping)
+ * windowId <- ceil((timestamp - startTime) / slideDuration)
+ * windowStart <- windowId * slideDuration + (i - maxNumOverlapping) * slideDuration + startTime
+ * windowEnd <- windowStart + windowDuration
+ * return windowStart, windowEnd
+ *
+ * This behaves as follows for the time 12:05 with the given parameters. The valid windows are
+ * marked with a +, and invalid ones are marked with an x. The invalid ones are filtered out by the
+ * Filter operator.
+ * window: 12m, slide: 5m, start: 0m :: window: 12m, slide: 5m, start: 2m
+ * 11:55 - 12:07 + 11:52 - 12:04 x
+ * 12:00 - 12:12 + 11:57 - 12:09 +
+ * 12:05 - 12:17 + 12:02 - 12:14 +
+ *
+ * @param plan The logical plan
+ * @return the logical plan that will generate the time windows using the Expand operator, with
+ * the Filter operator for correctness and Project for usability.
+ */
+ def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
+ case p: LogicalPlan if p.children.size == 1 =>
+ val child = p.children.head
+ val windowExpressions =
+ p.expressions.flatMap(_.collect { case t: TimeWindow => t }).distinct.toList // Not correct.
+
+ // Only support a single window expression for now
+ if (windowExpressions.size == 1 &&
+ windowExpressions.head.timeColumn.resolved &&
+ windowExpressions.head.checkInputDataTypes().isSuccess) {
+ val window = windowExpressions.head
+ val windowAttr = AttributeReference("window", window.dataType)()
+
+ val maxNumOverlapping = math.ceil(window.windowDuration * 1.0 / window.slideDuration).toInt
+ val windows = Seq.tabulate(maxNumOverlapping + 1) { i =>
+ val windowId = Ceil((PreciseTimestamp(window.timeColumn) - window.startTime) /
+ window.slideDuration)
+ val windowStart = (windowId + i - maxNumOverlapping) *
+ window.slideDuration + window.startTime
+ val windowEnd = windowStart + window.windowDuration
+
+ CreateNamedStruct(
+ Literal(WINDOW_START) :: windowStart ::
+ Literal(WINDOW_END) :: windowEnd :: Nil)
+ }
+
+ val projections = windows.map(_ +: p.children.head.output)
+
+ val filterExpr =
+ window.timeColumn >= windowAttr.getField(WINDOW_START) &&
+ window.timeColumn < windowAttr.getField(WINDOW_END)
+
+ val expandedPlan =
+ Filter(filterExpr,
+ Expand(projections, windowAttr +: child.output, child))
+
+ val substitutedPlan = p transformExpressions {
+ case t: TimeWindow => windowAttr
+ }
+
+ substitutedPlan.withNewChildren(expandedPlan :: Nil)
+ } else if (windowExpressions.size > 1) {
+ p.failAnalysis("Multiple time window expressions would result in a cartesian product " +
+ "of rows, therefore they are not currently not supported.")
+ } else {
+ p // Return unchanged. Analyzer will throw exception later
}
- }
}
}
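
To make the over-estimate-then-filter scheme concrete, here is a standalone numeric re-statement of the window arithmetic above, in plain Scala with minutes instead of microsecond timestamps; the helper names are ours, not part of the rule.

    // Mirrors the Expand step: maxNumOverlapping + 1 candidate windows per input row.
    def candidateWindows(ts: Long, windowLen: Long, slide: Long, start: Long): Seq[(Long, Long)] = {
      val maxNumOverlapping = math.ceil(windowLen.toDouble / slide).toInt
      Seq.tabulate(maxNumOverlapping + 1) { i =>
        val windowId = math.ceil((ts - start).toDouble / slide).toLong
        val windowStart = (windowId + i - maxNumOverlapping) * slide + start
        (windowStart, windowStart + windowLen)
      }
    }

    def fmt(m: Long): String = f"${m / 60}%d:${m % 60}%02d"

    // 12:05 with window 12m, slide 5m, start 0m: Expand produces candidates starting at
    // 11:50, 11:55, 12:00 and 12:05; the Filter step keeps the three that contain 12:05,
    // matching the '+' rows in the table above (11:55-12:07, 12:00-12:12, 12:05-12:17).
    val t = 12L * 60 + 5
    candidateWindows(t, windowLen = 12, slide = 5, start = 0)
      .filter { case (s, e) => t >= s && t < e }
      .foreach { case (s, e) => println(s"${fmt(s)} - ${fmt(e)}") }
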
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index 1d1e892e32..d6a8c3eec8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -52,7 +52,7 @@ trait CheckAnalysis {
case p if p.analyzed => // Skip already analyzed sub-plans
case u: UnresolvedRelation =>
- u.failAnalysis(s"Table not found: ${u.tableIdentifier}")
+ u.failAnalysis(s"Table or View not found: ${u.tableIdentifier}")
case operator: LogicalPlan =>
operator transformExpressionsUp {
@@ -76,7 +76,7 @@ trait CheckAnalysis {
case g: GroupingID =>
failAnalysis(s"grouping_id() can only be used with GroupingSets/Cube/Rollup")
- case w @ WindowExpression(AggregateExpression(_, _, true), _) =>
+ case w @ WindowExpression(AggregateExpression(_, _, true, _), _) =>
failAnalysis(s"Distinct window functions are not supported: $w")
case w @ WindowExpression(_: OffsetWindowFunction, WindowSpecDefinition(_, order,
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index f584a4b73a..f2abf136da 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -45,6 +45,19 @@ trait FunctionRegistry {
/* Get the class of the registered function by specified name. */
def lookupFunction(name: String): Option[ExpressionInfo]
+
+ /* Get the builder of the registered function by specified name. */
+ def lookupFunctionBuilder(name: String): Option[FunctionBuilder]
+
+ /** Drop a function and return whether the function existed. */
+ def dropFunction(name: String): Boolean
+
+ /** Checks if a function with a given name exists. */
+ def functionExists(name: String): Boolean = lookupFunction(name).isDefined
+
+ /** Clear all registered functions. */
+ def clear(): Unit
+
}
class SimpleFunctionRegistry extends FunctionRegistry {
@@ -76,6 +89,18 @@ class SimpleFunctionRegistry extends FunctionRegistry {
functionBuilders.get(name).map(_._1)
}
+ override def lookupFunctionBuilder(name: String): Option[FunctionBuilder] = synchronized {
+ functionBuilders.get(name).map(_._2)
+ }
+
+ override def dropFunction(name: String): Boolean = synchronized {
+ functionBuilders.remove(name).isDefined
+ }
+
+ override def clear(): Unit = {
+ functionBuilders.clear()
+ }
+
def copy(): SimpleFunctionRegistry = synchronized {
val registry = new SimpleFunctionRegistry
functionBuilders.iterator.foreach { case (name, (info, builder)) =>
@@ -106,6 +131,19 @@ object EmptyFunctionRegistry extends FunctionRegistry {
override def lookupFunction(name: String): Option[ExpressionInfo] = {
throw new UnsupportedOperationException
}
+
+ override def lookupFunctionBuilder(name: String): Option[FunctionBuilder] = {
+ throw new UnsupportedOperationException
+ }
+
+ override def dropFunction(name: String): Boolean = {
+ throw new UnsupportedOperationException
+ }
+
+ override def clear(): Unit = {
+ throw new UnsupportedOperationException
+ }
+
}
@@ -133,6 +171,7 @@ object FunctionRegistry {
expression[Rand]("rand"),
expression[Randn]("randn"),
expression[CreateStruct]("struct"),
+ expression[CaseWhen]("when"),
// math functions
expression[Acos]("acos"),
@@ -179,6 +218,12 @@ object FunctionRegistry {
expression[Tan]("tan"),
expression[Tanh]("tanh"),
+ expression[Add]("+"),
+ expression[Subtract]("-"),
+ expression[Multiply]("*"),
+ expression[Divide]("/"),
+ expression[Remainder]("%"),
+
// aggregate functions
expression[HyperLogLogPlusPlus]("approx_count_distinct"),
expression[Average]("avg"),
@@ -219,6 +264,7 @@ object FunctionRegistry {
expression[Lower]("lcase"),
expression[Length]("length"),
expression[Levenshtein]("levenshtein"),
+ expression[Like]("like"),
expression[Lower]("lower"),
expression[StringLocate]("locate"),
expression[StringLPad]("lpad"),
@@ -229,6 +275,7 @@ object FunctionRegistry {
expression[RegExpReplace]("regexp_replace"),
expression[StringRepeat]("repeat"),
expression[StringReverse]("reverse"),
+ expression[RLike]("rlike"),
expression[StringRPad]("rpad"),
expression[StringTrimRight]("rtrim"),
expression[SoundEx]("soundex"),
@@ -273,6 +320,7 @@ object FunctionRegistry {
expression[UnixTimestamp]("unix_timestamp"),
expression[WeekOfYear]("weekofyear"),
expression[Year]("year"),
+ expression[TimeWindow]("window"),
// collection functions
expression[ArrayContains]("array_contains"),
@@ -304,7 +352,29 @@ object FunctionRegistry {
expression[NTile]("ntile"),
expression[Rank]("rank"),
expression[DenseRank]("dense_rank"),
- expression[PercentRank]("percent_rank")
+ expression[PercentRank]("percent_rank"),
+
+ // predicates
+ expression[And]("and"),
+ expression[In]("in"),
+ expression[Not]("not"),
+ expression[Or]("or"),
+
+ expression[EqualNullSafe]("<=>"),
+ expression[EqualTo]("="),
+ expression[EqualTo]("=="),
+ expression[GreaterThan](">"),
+ expression[GreaterThanOrEqual](">="),
+ expression[LessThan]("<"),
+ expression[LessThanOrEqual]("<="),
+ expression[Not]("!"),
+
+ // bitwise
+ expression[BitwiseAnd]("&"),
+ expression[BitwiseNot]("~"),
+ expression[BitwiseOr]("|"),
+ expression[BitwiseXor]("^")
+
)
val builtin: SimpleFunctionRegistry = {
@@ -337,7 +407,10 @@ object FunctionRegistry {
}
Try(f.newInstance(expressions : _*).asInstanceOf[Expression]) match {
case Success(e) => e
- case Failure(e) => throw new AnalysisException(e.getMessage)
+ case Failure(e) =>
+ // the exception is an invocation exception. To get a meaningful message, we need the
+ // cause.
+ throw new AnalysisException(e.getCause.getMessage)
}
}
}
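
The new registry surface composes as below. This is a minimal sketch that works on a copy of the builtin registry so the shared instance stays untouched; the function names used are the ones registered above.

    import org.apache.spark.sql.catalyst.analysis.FunctionRegistry

    // copy() duplicates the builder map, so drops and clears stay local to this copy.
    val registry = FunctionRegistry.builtin.copy()

    // "when", "like" and "rlike" were just added to the builtin list above.
    assert(registry.functionExists("when"))
    println(registry.lookupFunction("like").map(_.getName))     // Some(like)
    println(registry.lookupFunctionBuilder("rlike").isDefined)  // true

    assert(registry.dropFunction("like"))      // true: it existed and is now removed
    assert(!registry.functionExists("like"))

    registry.clear()                           // empties only this copy
    assert(!registry.functionExists("when"))
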
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/NoSuchItemException.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/NoSuchItemException.scala
index e9f04eecf8..5e18316c94 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/NoSuchItemException.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/NoSuchItemException.scala
@@ -17,6 +17,7 @@
package org.apache.spark.sql.catalyst.analysis
+import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.catalog.ExternalCatalog.TablePartitionSpec
@@ -24,29 +25,13 @@ import org.apache.spark.sql.catalyst.catalog.ExternalCatalog.TablePartitionSpec
* Thrown by a catalog when an item cannot be found. The analyzer will rethrow the exception
* as an [[org.apache.spark.sql.AnalysisException]] with the correct position information.
*/
-abstract class NoSuchItemException extends Exception {
- override def getMessage: String
-}
+class NoSuchDatabaseException(db: String) extends AnalysisException(s"Database $db not found")
-class NoSuchDatabaseException(db: String) extends NoSuchItemException {
- override def getMessage: String = s"Database $db not found"
-}
+class NoSuchTableException(db: String, table: String)
+ extends AnalysisException(s"Table or View $table not found in database $db")
-class NoSuchTableException(db: String, table: String) extends NoSuchItemException {
- override def getMessage: String = s"Table $table not found in database $db"
-}
+class NoSuchPartitionException(db: String, table: String, spec: TablePartitionSpec) extends
+ AnalysisException(s"Partition not found in table $table database $db:\n" + spec.mkString("\n"))
-class NoSuchPartitionException(
- db: String,
- table: String,
- spec: TablePartitionSpec)
- extends NoSuchItemException {
-
- override def getMessage: String = {
- s"Partition not found in table $table database $db:\n" + spec.mkString("\n")
- }
-}
-
-class NoSuchFunctionException(db: String, func: String) extends NoSuchItemException {
- override def getMessage: String = s"Function $func not found in database $db"
-}
+class NoSuchFunctionException(db: String, func: String)
+ extends AnalysisException(s"Function $func not found in database $db")
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
index e73d367a73..4ec43aba02 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
@@ -18,9 +18,9 @@
package org.apache.spark.sql.catalyst.analysis
import org.apache.spark.sql.AnalysisException
-import org.apache.spark.sql.catalyst.{errors, TableIdentifier}
+import org.apache.spark.sql.catalyst.{errors, InternalRow, TableIdentifier}
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodegenFallback, ExprCode}
import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan}
import org.apache.spark.sql.catalyst.trees.TreeNode
import org.apache.spark.sql.catalyst.util.quoteIdentifier
@@ -133,6 +133,33 @@ object UnresolvedAttribute {
}
}
+/**
+ * Represents an unresolved generator, which will be created by the parser for
+ * the [[org.apache.spark.sql.catalyst.plans.logical.Generate]] operator.
+ * The analyzer will resolve this generator.
+ */
+case class UnresolvedGenerator(name: String, children: Seq[Expression]) extends Generator {
+
+ override def elementTypes: Seq[(DataType, Boolean, String)] =
+ throw new UnresolvedException(this, "elementTypes")
+ override def dataType: DataType = throw new UnresolvedException(this, "dataType")
+ override def foldable: Boolean = throw new UnresolvedException(this, "foldable")
+ override def nullable: Boolean = throw new UnresolvedException(this, "nullable")
+ override lazy val resolved = false
+
+ override def prettyName: String = name
+ override def toString: String = s"'$name(${children.mkString(", ")})"
+
+ override def eval(input: InternalRow = null): TraversableOnce[InternalRow] =
+ throw new UnsupportedOperationException(s"Cannot evaluate expression: $this")
+
+ override protected def genCode(ctx: CodegenContext, ev: ExprCode): String =
+ throw new UnsupportedOperationException(s"Cannot evaluate expression: $this")
+
+ override def terminate(): TraversableOnce[InternalRow] =
+ throw new UnsupportedOperationException(s"Cannot evaluate expression: $this")
+}
+
case class UnresolvedFunction(
name: String,
children: Seq[Expression],
@@ -307,3 +334,25 @@ case class UnresolvedAlias(child: Expression, aliasName: Option[String] = None)
override lazy val resolved = false
}
+
+/**
+ * Holds the deserializer expression and the attributes that are available during the resolution
+ * for it. A deserializer expression is a special kind of expression that is not always resolved by
+ * its children's output, but by given attributes, e.g. the `keyDeserializer` in `MapGroups` should be
+ * resolved by `groupingAttributes` instead of the children's output.
+ *
+ * @param deserializer The unresolved deserializer expression
+ * @param inputAttributes The input attributes used to resolve the deserializer expression; can be empty
+ * if we want to resolve the deserializer by the children's output.
+ */
+case class UnresolvedDeserializer(deserializer: Expression, inputAttributes: Seq[Attribute] = Nil)
+ extends UnaryExpression with Unevaluable with NonSQLExpression {
+ // The input attributes used to resolve deserializer expression must be all resolved.
+ require(inputAttributes.forall(_.resolved), "Input attributes must all be resolved.")
+
+ override def child: Expression = deserializer
+ override def dataType: DataType = throw new UnresolvedException(this, "dataType")
+ override def foldable: Boolean = throw new UnresolvedException(this, "foldable")
+ override def nullable: Boolean = throw new UnresolvedException(this, "nullable")
+ override lazy val resolved = false
+}
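
A sketch of how an operator is expected to use the new wrapper (catalyst-internal API; `Key` is a made-up example): the deserializer is left unresolved, and the ResolveDeserializer rule added in Analyzer.scala binds it later, either to the explicitly supplied attributes or, when inputAttributes is empty, to the child's output.

    import org.apache.spark.sql.catalyst.ScalaReflection
    import org.apache.spark.sql.catalyst.analysis.UnresolvedDeserializer

    case class Key(id: Long)   // hypothetical grouping key type

    // e.g. what MapGroups would carry as its keyDeserializer: it stays unresolved until
    // the analyzer substitutes the grouping attributes for the BoundReferences inside.
    val keyDeserializer = UnresolvedDeserializer(ScalaReflection.deserializerFor[Key])
    println(keyDeserializer.resolved)   // false
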
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala
index e216fa5528..f8a6fb74cc 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala
@@ -21,7 +21,7 @@ import scala.collection.mutable
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier}
-
+import org.apache.spark.sql.catalyst.util.StringUtils
/**
* An in-memory (ephemeral) implementation of the system catalog.
@@ -47,16 +47,6 @@ class InMemoryCatalog extends ExternalCatalog {
// Database name -> description
private val catalog = new scala.collection.mutable.HashMap[String, DatabaseDesc]
- private def filterPattern(names: Seq[String], pattern: String): Seq[String] = {
- val regex = pattern.replaceAll("\\*", ".*").r
- names.filter { funcName => regex.pattern.matcher(funcName).matches() }
- }
-
- private def functionExists(db: String, funcName: String): Boolean = {
- requireDbExists(db)
- catalog(db).functions.contains(funcName)
- }
-
private def partitionExists(db: String, table: String, spec: TablePartitionSpec): Boolean = {
requireTableExists(db, table)
catalog(db).tables(table).partitions.contains(spec)
@@ -72,7 +62,7 @@ class InMemoryCatalog extends ExternalCatalog {
private def requireTableExists(db: String, table: String): Unit = {
if (!tableExists(db, table)) {
throw new AnalysisException(
- s"Table not found: '$table' does not exist in database '$db'")
+ s"Table or View not found: '$table' does not exist in database '$db'")
}
}
@@ -141,7 +131,7 @@ class InMemoryCatalog extends ExternalCatalog {
}
override def listDatabases(pattern: String): Seq[String] = synchronized {
- filterPattern(listDatabases(), pattern)
+ StringUtils.filterPattern(listDatabases(), pattern)
}
override def setCurrentDatabase(db: String): Unit = { /* no-op */ }
@@ -155,7 +145,7 @@ class InMemoryCatalog extends ExternalCatalog {
tableDefinition: CatalogTable,
ignoreIfExists: Boolean): Unit = synchronized {
requireDbExists(db)
- val table = tableDefinition.name.table
+ val table = tableDefinition.identifier.table
if (tableExists(db, table)) {
if (!ignoreIfExists) {
throw new AnalysisException(s"Table '$table' already exists in database '$db'")
@@ -174,7 +164,7 @@ class InMemoryCatalog extends ExternalCatalog {
catalog(db).tables.remove(table)
} else {
if (!ignoreIfNotExists) {
- throw new AnalysisException(s"Table '$table' does not exist in database '$db'")
+ throw new AnalysisException(s"Table or View '$table' does not exist in database '$db'")
}
}
}
@@ -182,14 +172,14 @@ class InMemoryCatalog extends ExternalCatalog {
override def renameTable(db: String, oldName: String, newName: String): Unit = synchronized {
requireTableExists(db, oldName)
val oldDesc = catalog(db).tables(oldName)
- oldDesc.table = oldDesc.table.copy(name = TableIdentifier(newName, Some(db)))
+ oldDesc.table = oldDesc.table.copy(identifier = TableIdentifier(newName, Some(db)))
catalog(db).tables.put(newName, oldDesc)
catalog(db).tables.remove(oldName)
}
override def alterTable(db: String, tableDefinition: CatalogTable): Unit = synchronized {
- requireTableExists(db, tableDefinition.name.table)
- catalog(db).tables(tableDefinition.name.table).table = tableDefinition
+ requireTableExists(db, tableDefinition.identifier.table)
+ catalog(db).tables(tableDefinition.identifier.table).table = tableDefinition
}
override def getTable(db: String, table: String): CatalogTable = synchronized {
@@ -197,6 +187,10 @@ class InMemoryCatalog extends ExternalCatalog {
catalog(db).tables(table).table
}
+ override def getTableOption(db: String, table: String): Option[CatalogTable] = synchronized {
+ if (!tableExists(db, table)) None else Option(catalog(db).tables(table).table)
+ }
+
override def tableExists(db: String, table: String): Boolean = synchronized {
requireDbExists(db)
catalog(db).tables.contains(table)
@@ -208,7 +202,7 @@ class InMemoryCatalog extends ExternalCatalog {
}
override def listTables(db: String, pattern: String): Seq[String] = synchronized {
- filterPattern(listTables(db), pattern)
+ StringUtils.filterPattern(listTables(db), pattern)
}
// --------------------------------------------------------------------------
@@ -296,10 +290,10 @@ class InMemoryCatalog extends ExternalCatalog {
override def createFunction(db: String, func: CatalogFunction): Unit = synchronized {
requireDbExists(db)
- if (functionExists(db, func.name.funcName)) {
+ if (functionExists(db, func.identifier.funcName)) {
throw new AnalysisException(s"Function '$func' already exists in '$db' database")
} else {
- catalog(db).functions.put(func.name.funcName, func)
+ catalog(db).functions.put(func.identifier.funcName, func)
}
}
@@ -310,24 +304,24 @@ class InMemoryCatalog extends ExternalCatalog {
override def renameFunction(db: String, oldName: String, newName: String): Unit = synchronized {
requireFunctionExists(db, oldName)
- val newFunc = getFunction(db, oldName).copy(name = FunctionIdentifier(newName, Some(db)))
+ val newFunc = getFunction(db, oldName).copy(identifier = FunctionIdentifier(newName, Some(db)))
catalog(db).functions.remove(oldName)
catalog(db).functions.put(newName, newFunc)
}
- override def alterFunction(db: String, funcDefinition: CatalogFunction): Unit = synchronized {
- requireFunctionExists(db, funcDefinition.name.funcName)
- catalog(db).functions.put(funcDefinition.name.funcName, funcDefinition)
- }
-
override def getFunction(db: String, funcName: String): CatalogFunction = synchronized {
requireFunctionExists(db, funcName)
catalog(db).functions(funcName)
}
+ override def functionExists(db: String, funcName: String): Boolean = {
+ requireDbExists(db)
+ catalog(db).functions.contains(funcName)
+ }
+
override def listFunctions(db: String, pattern: String): Seq[String] = synchronized {
requireDbExists(db)
- filterPattern(catalog(db).functions.keysIterator.toSeq, pattern)
+ StringUtils.filterPattern(catalog(db).functions.keysIterator.toSeq, pattern)
}
}
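
The hunks above swap the catalog's private filterPattern helper for the shared StringUtils.filterPattern. As a rough, self-contained illustration of the wildcard semantics this relies on (an assumption here: '*' matches any character sequence and '|' separates alternative patterns, as in SHOW TABLES LIKE; this is a sketch, not the Spark implementation):

// Stand-alone sketch of wildcard filtering similar in spirit to
// StringUtils.filterPattern; the '*' and '|' semantics are assumptions.
object FilterPatternSketch {
  def filterPattern(names: Seq[String], pattern: String): Seq[String] = {
    // Each '|'-separated alternative becomes a case-insensitive regex
    // in which '*' is translated to '.*'.
    val regexes = pattern.trim.split("\\|").map { p =>
      ("(?i)" + p.trim.replaceAll("\\*", ".*")).r
    }
    names.filter(name => regexes.exists(_.pattern.matcher(name).matches())).distinct
  }

  def main(args: Array[String]): Unit = {
    val tables = Seq("sales", "sales_2016", "users")
    println(filterPattern(tables, "sales*"))    // List(sales, sales_2016)
    println(filterPattern(tables, "u*|sales"))  // List(sales, users)
  }
}
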
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala
index 34265faa74..34e1cb7315 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala
@@ -17,30 +17,47 @@
package org.apache.spark.sql.catalyst.catalog
-import java.util.concurrent.ConcurrentHashMap
+import java.io.File
-import scala.collection.JavaConverters._
+import scala.collection.mutable
+import org.apache.spark.internal.Logging
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.{CatalystConf, SimpleCatalystConf}
import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier}
+import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, NoSuchFunctionException, SimpleFunctionRegistry}
+import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder
+import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionInfo}
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, SubqueryAlias}
-
+import org.apache.spark.sql.catalyst.util.StringUtils
/**
* An internal catalog that is used by a Spark Session. This internal catalog serves as a
* proxy to the underlying metastore (e.g. Hive Metastore) and it also manages temporary
* tables and functions of the Spark Session that it belongs to.
+ *
+ * This class is not thread-safe.
*/
-class SessionCatalog(externalCatalog: ExternalCatalog, conf: CatalystConf) {
+class SessionCatalog(
+ externalCatalog: ExternalCatalog,
+ functionResourceLoader: FunctionResourceLoader,
+ functionRegistry: FunctionRegistry,
+ conf: CatalystConf) extends Logging {
import ExternalCatalog._
+ def this(
+ externalCatalog: ExternalCatalog,
+ functionRegistry: FunctionRegistry,
+ conf: CatalystConf) {
+ this(externalCatalog, DummyFunctionResourceLoader, functionRegistry, conf)
+ }
+
+ // For testing only.
def this(externalCatalog: ExternalCatalog) {
- this(externalCatalog, new SimpleCatalystConf(true))
+ this(externalCatalog, new SimpleFunctionRegistry, new SimpleCatalystConf(true))
}
- protected[this] val tempTables = new ConcurrentHashMap[String, LogicalPlan]
- protected[this] val tempFunctions = new ConcurrentHashMap[String, CatalogFunction]
+ protected[this] val tempTables = new mutable.HashMap[String, LogicalPlan]
// Note: we track current database here because certain operations do not explicitly
// specify the database (e.g. DROP TABLE my_table). In these cases we must first
@@ -79,7 +96,7 @@ class SessionCatalog(externalCatalog: ExternalCatalog, conf: CatalystConf) {
externalCatalog.alterDatabase(dbDefinition)
}
- def getDatabase(db: String): CatalogDatabase = {
+ def getDatabaseMetadata(db: String): CatalogDatabase = {
externalCatalog.getDatabase(db)
}
@@ -104,6 +121,10 @@ class SessionCatalog(externalCatalog: ExternalCatalog, conf: CatalystConf) {
currentDb = db
}
+ def getDefaultDBPath(db: String): String = {
+ System.getProperty("java.io.tmpdir") + File.separator + db + ".db"
+ }
+
// ----------------------------------------------------------------------------
// Tables
// ----------------------------------------------------------------------------
@@ -122,9 +143,9 @@ class SessionCatalog(externalCatalog: ExternalCatalog, conf: CatalystConf) {
* If no such database is specified, create it in the current database.
*/
def createTable(tableDefinition: CatalogTable, ignoreIfExists: Boolean): Unit = {
- val db = tableDefinition.name.database.getOrElse(currentDb)
- val table = formatTableName(tableDefinition.name.table)
- val newTableDefinition = tableDefinition.copy(name = TableIdentifier(table, Some(db)))
+ val db = tableDefinition.identifier.database.getOrElse(currentDb)
+ val table = formatTableName(tableDefinition.identifier.table)
+ val newTableDefinition = tableDefinition.copy(identifier = TableIdentifier(table, Some(db)))
externalCatalog.createTable(db, newTableDefinition, ignoreIfExists)
}
@@ -138,22 +159,34 @@ class SessionCatalog(externalCatalog: ExternalCatalog, conf: CatalystConf) {
* this becomes a no-op.
*/
def alterTable(tableDefinition: CatalogTable): Unit = {
- val db = tableDefinition.name.database.getOrElse(currentDb)
- val table = formatTableName(tableDefinition.name.table)
- val newTableDefinition = tableDefinition.copy(name = TableIdentifier(table, Some(db)))
+ val db = tableDefinition.identifier.database.getOrElse(currentDb)
+ val table = formatTableName(tableDefinition.identifier.table)
+ val newTableDefinition = tableDefinition.copy(identifier = TableIdentifier(table, Some(db)))
externalCatalog.alterTable(db, newTableDefinition)
}
/**
* Retrieve the metadata of an existing metastore table.
* If no database is specified, assume the table is in the current database.
+ * If the specified table is not found in the database then an [[AnalysisException]] is thrown.
*/
- def getTable(name: TableIdentifier): CatalogTable = {
+ def getTableMetadata(name: TableIdentifier): CatalogTable = {
val db = name.database.getOrElse(currentDb)
val table = formatTableName(name.table)
externalCatalog.getTable(db, table)
}
+ /**
+ * Retrieve the metadata of an existing metastore table.
+ * If no database is specified, assume the table is in the current database.
+ * If the specified table is not found in the database, None is returned.
+ */
+ def getTableMetadataOption(name: TableIdentifier): Option[CatalogTable] = {
+ val db = name.database.getOrElse(currentDb)
+ val table = formatTableName(name.table)
+ externalCatalog.getTableOption(db, table)
+ }
+
// -------------------------------------------------------------
// | Methods that interact with temporary and metastore tables |
// -------------------------------------------------------------
@@ -164,9 +197,9 @@ class SessionCatalog(externalCatalog: ExternalCatalog, conf: CatalystConf) {
def createTempTable(
name: String,
tableDefinition: LogicalPlan,
- ignoreIfExists: Boolean): Unit = {
+ overrideIfExists: Boolean): Unit = {
val table = formatTableName(name)
- if (tempTables.containsKey(table) && !ignoreIfExists) {
+ if (tempTables.contains(table) && !overrideIfExists) {
throw new AnalysisException(s"Temporary table '$name' already exists.")
}
tempTables.put(table, tableDefinition)
@@ -188,10 +221,11 @@ class SessionCatalog(externalCatalog: ExternalCatalog, conf: CatalystConf) {
val db = oldName.database.getOrElse(currentDb)
val oldTableName = formatTableName(oldName.table)
val newTableName = formatTableName(newName.table)
- if (oldName.database.isDefined || !tempTables.containsKey(oldTableName)) {
+ if (oldName.database.isDefined || !tempTables.contains(oldTableName)) {
externalCatalog.renameTable(db, oldTableName, newTableName)
} else {
- val table = tempTables.remove(oldTableName)
+ val table = tempTables(oldTableName)
+ tempTables.remove(oldTableName)
tempTables.put(newTableName, table)
}
}
@@ -206,8 +240,14 @@ class SessionCatalog(externalCatalog: ExternalCatalog, conf: CatalystConf) {
def dropTable(name: TableIdentifier, ignoreIfNotExists: Boolean): Unit = {
val db = name.database.getOrElse(currentDb)
val table = formatTableName(name.table)
- if (name.database.isDefined || !tempTables.containsKey(table)) {
- externalCatalog.dropTable(db, table, ignoreIfNotExists)
+ if (name.database.isDefined || !tempTables.contains(table)) {
+ // If the table does not exist and ignoreIfNotExists is false, do not throw an exception;
+ // instead, log it as an error message.
+ if (externalCatalog.tableExists(db, table)) {
+ externalCatalog.dropTable(db, table, ignoreIfNotExists = true)
+ } else if (!ignoreIfNotExists) {
+ logError(s"Table or View '${name.quotedString}' does not exist")
+ }
} else {
tempTables.remove(table)
}
@@ -224,11 +264,11 @@ class SessionCatalog(externalCatalog: ExternalCatalog, conf: CatalystConf) {
val db = name.database.getOrElse(currentDb)
val table = formatTableName(name.table)
val relation =
- if (name.database.isDefined || !tempTables.containsKey(table)) {
+ if (name.database.isDefined || !tempTables.contains(table)) {
val metadata = externalCatalog.getTable(db, table)
CatalogRelation(db, metadata, alias)
} else {
- tempTables.get(table)
+ tempTables(table)
}
val qualifiedTable = SubqueryAlias(table, relation)
// If an alias was specified by the lookup, wrap the plan in a subquery so that
@@ -247,7 +287,7 @@ class SessionCatalog(externalCatalog: ExternalCatalog, conf: CatalystConf) {
def tableExists(name: TableIdentifier): Boolean = {
val db = name.database.getOrElse(currentDb)
val table = formatTableName(name.table)
- if (name.database.isDefined || !tempTables.containsKey(table)) {
+ if (name.database.isDefined || !tempTables.contains(table)) {
externalCatalog.tableExists(db, table)
} else {
true // it's a temporary table
@@ -255,6 +295,16 @@ class SessionCatalog(externalCatalog: ExternalCatalog, conf: CatalystConf) {
}
/**
+ * Return whether a table with the specified name is a temporary table.
+ *
+ * Note: The temporary table cache is checked only when the database is not
+ * explicitly specified.
+ */
+ def isTemporaryTable(name: TableIdentifier): Boolean = {
+ name.database.isEmpty && tempTables.contains(formatTableName(name.table))
+ }
+
+ /**
* List all tables in the specified database, including temporary tables.
*/
def listTables(db: String): Seq[TableIdentifier] = listTables(db, "*")
@@ -265,19 +315,24 @@ class SessionCatalog(externalCatalog: ExternalCatalog, conf: CatalystConf) {
def listTables(db: String, pattern: String): Seq[TableIdentifier] = {
val dbTables =
externalCatalog.listTables(db, pattern).map { t => TableIdentifier(t, Some(db)) }
- val regex = pattern.replaceAll("\\*", ".*").r
- val _tempTables = tempTables.keys().asScala
- .filter { t => regex.pattern.matcher(t).matches() }
+ val _tempTables = StringUtils.filterPattern(tempTables.keys.toSeq, pattern)
.map { t => TableIdentifier(t) }
dbTables ++ _tempTables
}
+ // TODO: It's strange that we have both refresh and invalidate here.
+
/**
* Refresh the cache entry for a metastore table, if any.
*/
def refreshTable(name: TableIdentifier): Unit = { /* no-op */ }
/**
+ * Invalidate the cache entry for a metastore table, if any.
+ */
+ def invalidateTable(name: TableIdentifier): Unit = { /* no-op */ }
+
+ /**
* Drop all existing temporary tables.
* For testing only.
*/
@@ -290,7 +345,7 @@ class SessionCatalog(externalCatalog: ExternalCatalog, conf: CatalystConf) {
* For testing only.
*/
private[catalog] def getTempTable(name: String): Option[LogicalPlan] = {
- Option(tempTables.get(name))
+ tempTables.get(name)
}
// ----------------------------------------------------------------------------
@@ -398,36 +453,57 @@ class SessionCatalog(externalCatalog: ExternalCatalog, conf: CatalystConf) {
* Create a metastore function in the database specified in `funcDefinition`.
* If no such database is specified, create it in the current database.
*/
- def createFunction(funcDefinition: CatalogFunction): Unit = {
- val db = funcDefinition.name.database.getOrElse(currentDb)
- val newFuncDefinition = funcDefinition.copy(
- name = FunctionIdentifier(funcDefinition.name.funcName, Some(db)))
- externalCatalog.createFunction(db, newFuncDefinition)
+ def createFunction(funcDefinition: CatalogFunction, ignoreIfExists: Boolean): Unit = {
+ val db = funcDefinition.identifier.database.getOrElse(currentDb)
+ val identifier = FunctionIdentifier(funcDefinition.identifier.funcName, Some(db))
+ val newFuncDefinition = funcDefinition.copy(identifier = identifier)
+ if (!functionExists(identifier)) {
+ externalCatalog.createFunction(db, newFuncDefinition)
+ } else if (!ignoreIfExists) {
+ throw new AnalysisException(s"function '$identifier' already exists in database '$db'")
+ }
}
/**
* Drop a metastore function.
* If no database is specified, assume the function is in the current database.
*/
- def dropFunction(name: FunctionIdentifier): Unit = {
+ def dropFunction(name: FunctionIdentifier, ignoreIfNotExists: Boolean): Unit = {
val db = name.database.getOrElse(currentDb)
- externalCatalog.dropFunction(db, name.funcName)
+ val identifier = name.copy(database = Some(db))
+ if (functionExists(identifier)) {
+ // TODO: registry should just take in FunctionIdentifier for type safety
+ if (functionRegistry.functionExists(identifier.unquotedString)) {
+ // If we have loaded this function into the FunctionRegistry,
+ // also drop it from there.
+ // For a permanent function, because we loaded it to the FunctionRegistry
+ // when it's first used, we also need to drop it from the FunctionRegistry.
+ functionRegistry.dropFunction(identifier.unquotedString)
+ }
+ externalCatalog.dropFunction(db, name.funcName)
+ } else if (!ignoreIfNotExists) {
+ throw new AnalysisException(s"function '$identifier' does not exist in database '$db'")
+ }
}
/**
- * Alter a metastore function whose name that matches the one specified in `funcDefinition`.
- *
- * If no database is specified in `funcDefinition`, assume the function is in the
- * current database.
+ * Retrieve the metadata of a metastore function.
*
- * Note: If the underlying implementation does not support altering a certain field,
- * this becomes a no-op.
+ * If a database is specified in `name`, this will return the function in that database.
+ * If no database is specified, this will return the function in the current database.
*/
- def alterFunction(funcDefinition: CatalogFunction): Unit = {
- val db = funcDefinition.name.database.getOrElse(currentDb)
- val newFuncDefinition = funcDefinition.copy(
- name = FunctionIdentifier(funcDefinition.name.funcName, Some(db)))
- externalCatalog.alterFunction(db, newFuncDefinition)
+ def getFunctionMetadata(name: FunctionIdentifier): CatalogFunction = {
+ val db = name.database.getOrElse(currentDb)
+ externalCatalog.getFunction(db, name.funcName)
+ }
+
+ /**
+ * Check if the specified function exists.
+ */
+ def functionExists(name: FunctionIdentifier): Boolean = {
+ val db = name.database.getOrElse(currentDb)
+ functionRegistry.functionExists(name.unquotedString) ||
+ externalCatalog.functionExists(db, name.funcName)
}
// ----------------------------------------------------------------
@@ -435,17 +511,40 @@ class SessionCatalog(externalCatalog: ExternalCatalog, conf: CatalystConf) {
// ----------------------------------------------------------------
/**
+ * Construct a [[FunctionBuilder]] based on the provided class that represents a function.
+ *
+ * This performs reflection to decide what type of [[Expression]] to return in the builder.
+ */
+ private[sql] def makeFunctionBuilder(name: String, functionClassName: String): FunctionBuilder = {
+ // TODO: at least support UDAFs here
+ throw new UnsupportedOperationException("Use sqlContext.udf.register(...) instead.")
+ }
+
+ /**
+ * Loads resources such as JARs and Files for a function. Every resource is represented
+ * by a tuple (resource type, resource uri).
+ */
+ def loadFunctionResources(resources: Seq[(String, String)]): Unit = {
+ resources.foreach { case (resourceType, uri) =>
+ val functionResource =
+ FunctionResource(FunctionResourceType.fromString(resourceType.toLowerCase), uri)
+ functionResourceLoader.loadResource(functionResource)
+ }
+ }
+
+ /**
* Create a temporary function.
* This assumes no database is specified in `funcDefinition`.
*/
- def createTempFunction(funcDefinition: CatalogFunction, ignoreIfExists: Boolean): Unit = {
- require(funcDefinition.name.database.isEmpty,
- "attempted to create a temporary function while specifying a database")
- val name = funcDefinition.name.funcName
- if (tempFunctions.containsKey(name) && !ignoreIfExists) {
+ def createTempFunction(
+ name: String,
+ info: ExpressionInfo,
+ funcDefinition: FunctionBuilder,
+ ignoreIfExists: Boolean): Unit = {
+ if (functionRegistry.lookupFunctionBuilder(name).isDefined && !ignoreIfExists) {
throw new AnalysisException(s"Temporary function '$name' already exists.")
}
- tempFunctions.put(name, funcDefinition)
+ functionRegistry.registerFunction(name, info, funcDefinition)
}
/**
@@ -455,53 +554,71 @@ class SessionCatalog(externalCatalog: ExternalCatalog, conf: CatalystConf) {
// Hive has DROP FUNCTION and DROP TEMPORARY FUNCTION. We may want to consolidate
// dropFunction and dropTempFunction.
def dropTempFunction(name: String, ignoreIfNotExists: Boolean): Unit = {
- if (!tempFunctions.containsKey(name) && !ignoreIfNotExists) {
+ if (!functionRegistry.dropFunction(name) && !ignoreIfNotExists) {
throw new AnalysisException(
s"Temporary function '$name' cannot be dropped because it does not exist!")
}
- tempFunctions.remove(name)
+ }
+
+ protected def failFunctionLookup(name: String): Nothing = {
+ throw new AnalysisException(s"Undefined function: $name. This function is " +
+ s"neither a registered temporary function nor " +
+ s"a permanent function registered in the database $currentDb.")
}
/**
- * Rename a function.
+ * Return an [[Expression]] that represents the specified function, assuming it exists.
*
- * If a database is specified in `oldName`, this will rename the function in that database.
- * If no database is specified, this will first attempt to rename a temporary function with
- * the same name, then, if that does not exist, rename the function in the current database.
+ * For a temporary function or a permanent function that has been loaded,
+ * this method will simply lookup the function through the
+ * FunctionRegistry and create an expression based on the builder.
*
- * This assumes the database specified in `oldName` matches the one specified in `newName`.
+ * For a permanent function that has not been loaded, we will first fetch its metadata
+ * from the underlying external catalog. Then, we will load all resources associated
+ * with this function (i.e. jars and files). Finally, we create a function builder
+ * based on the function class and put the builder into the FunctionRegistry.
+ * The name of this function in the FunctionRegistry will be `databaseName.functionName`.
*/
- def renameFunction(oldName: FunctionIdentifier, newName: FunctionIdentifier): Unit = {
- if (oldName.database != newName.database) {
- throw new AnalysisException("rename does not support moving functions across databases")
- }
- val db = oldName.database.getOrElse(currentDb)
- if (oldName.database.isDefined || !tempFunctions.containsKey(oldName.funcName)) {
- externalCatalog.renameFunction(db, oldName.funcName, newName.funcName)
+ def lookupFunction(name: String, children: Seq[Expression]): Expression = {
+ // TODO: Right now, the name can be qualified or not qualified.
+ // It will be better to get a FunctionIdentifier.
+ // TODO: Right now, we assume that name is not qualified!
+ val qualifiedName = FunctionIdentifier(name, Some(currentDb)).unquotedString
+ if (functionRegistry.functionExists(name)) {
+ // This function has been already loaded into the function registry.
+ functionRegistry.lookupFunction(name, children)
+ } else if (functionRegistry.functionExists(qualifiedName)) {
+ // This function has been already loaded into the function registry.
+ // Unlike the above block, we find this function by using the qualified name.
+ functionRegistry.lookupFunction(qualifiedName, children)
} else {
- val func = tempFunctions.remove(oldName.funcName)
- val newFunc = func.copy(name = func.name.copy(funcName = newName.funcName))
- tempFunctions.put(newName.funcName, newFunc)
+ // The function has not been loaded to the function registry, which means
+ // that the function is a permanent function (if it actually has been registered
+ // in the metastore). We need to first put the function in the FunctionRegistry.
+ val catalogFunction = try {
+ externalCatalog.getFunction(currentDb, name)
+ } catch {
+ case e: AnalysisException => failFunctionLookup(name)
+ case e: NoSuchFunctionException => failFunctionLookup(name)
+ }
+ loadFunctionResources(catalogFunction.resources)
+ // Please note that qualifiedName is provided by the user. However,
+ // catalogFunction.identifier.unquotedString is returned by the underlying
+ // catalog. So, it is possible that qualifiedName is not exactly the same as
+ // catalogFunction.identifier.unquotedString (they may differ in case).
+ // Here, we preserve the input from the user.
+ val info = new ExpressionInfo(catalogFunction.className, qualifiedName)
+ val builder = makeFunctionBuilder(qualifiedName, catalogFunction.className)
+ createTempFunction(qualifiedName, info, builder, ignoreIfExists = false)
+ // Now, we need to create the Expression.
+ functionRegistry.lookupFunction(qualifiedName, children)
}
}
/**
- * Retrieve the metadata of an existing function.
- *
- * If a database is specified in `name`, this will return the function in that database.
- * If no database is specified, this will first attempt to return a temporary function with
- * the same name, then, if that does not exist, return the function in the current database.
+ * List all functions in the specified database, including temporary functions.
*/
- def getFunction(name: FunctionIdentifier): CatalogFunction = {
- val db = name.database.getOrElse(currentDb)
- if (name.database.isDefined || !tempFunctions.containsKey(name.funcName)) {
- externalCatalog.getFunction(db, name.funcName)
- } else {
- tempFunctions.get(name.funcName)
- }
- }
-
- // TODO: implement lookupFunction that returns something from the registry itself
+ def listFunctions(db: String): Seq[FunctionIdentifier] = listFunctions(db, "*")
/**
* List all matching functions in the specified database, including temporary functions.
@@ -509,18 +626,40 @@ class SessionCatalog(externalCatalog: ExternalCatalog, conf: CatalystConf) {
def listFunctions(db: String, pattern: String): Seq[FunctionIdentifier] = {
val dbFunctions =
externalCatalog.listFunctions(db, pattern).map { f => FunctionIdentifier(f, Some(db)) }
- val regex = pattern.replaceAll("\\*", ".*").r
- val _tempFunctions = tempFunctions.keys().asScala
- .filter { f => regex.pattern.matcher(f).matches() }
+ val loadedFunctions = StringUtils.filterPattern(functionRegistry.listFunction(), pattern)
.map { f => FunctionIdentifier(f) }
- dbFunctions ++ _tempFunctions
+ // TODO: Actually, there will be dbFunctions that have been loaded into the FunctionRegistry.
+ // So, the returned list may have two entries for the same function.
+ dbFunctions ++ loadedFunctions
}
+
+ // -----------------
+ // | Other methods |
+ // -----------------
+
/**
- * Return a temporary function. For testing only.
+ * Drop all existing databases (except "default") along with all associated tables,
+ * partitions and functions, and set the current database to "default".
+ *
+ * This is mainly used for tests.
*/
- private[catalog] def getTempFunction(name: String): Option[CatalogFunction] = {
- Option(tempFunctions.get(name))
+ private[sql] def reset(): Unit = {
+ val default = "default"
+ listDatabases().filter(_ != default).foreach { db =>
+ dropDatabase(db, ignoreIfNotExists = false, cascade = true)
+ }
+ tempTables.clear()
+ functionRegistry.clear()
+ // restore built-in functions
+ FunctionRegistry.builtin.listFunction().foreach { f =>
+ val expressionInfo = FunctionRegistry.builtin.lookupFunction(f)
+ val functionBuilder = FunctionRegistry.builtin.lookupFunctionBuilder(f)
+ require(expressionInfo.isDefined, s"built-in function '$f' is missing expression info")
+ require(functionBuilder.isDefined, s"built-in function '$f' is missing function builder")
+ functionRegistry.registerFunction(f, expressionInfo.get, functionBuilder.get)
+ }
+ setCurrentDatabase(default)
}
}
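
The new lookupFunction doc comment above describes a three-step resolution order: bare name in the registry, then the db-qualified name, then the external catalog (loading resources and caching a builder). A minimal, self-contained sketch of that order, using plain maps as stand-ins for FunctionRegistry and ExternalCatalog; every name below is a simplified placeholder, not a Spark API:

object LookupOrderSketch {
  type Builder = Seq[String] => String   // placeholder for FunctionBuilder

  val registry = scala.collection.mutable.Map[String, Builder]()
  val externalCatalog = Map("db1.myfunc" -> "com.example.MyFunc")  // name -> className

  def lookupFunction(currentDb: String, name: String, args: Seq[String]): String = {
    val qualified = s"$currentDb.$name"
    registry.get(name).orElse(registry.get(qualified)) match {
      case Some(builder) =>
        builder(args)                                   // already loaded: serve from the registry
      case None =>
        val className = externalCatalog.getOrElse(qualified,
          sys.error(s"Undefined function: $name"))      // not in the metastore either
        // In the real code this is where jars/files are loaded and a builder
        // is created from className via reflection.
        val builder: Builder = as => s"$className(${as.mkString(", ")})"
        registry.put(qualified, builder)                // cache under the qualified name
        builder(args)
    }
  }

  def main(args: Array[String]): Unit = {
    println(lookupFunction("db1", "myfunc", Seq("a", "b")))  // loads, caches, then builds
    println(lookupFunction("db1", "myfunc", Seq("c")))       // now served from the registry
  }
}
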
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/functionResources.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/functionResources.scala
new file mode 100644
index 0000000000..5adcc892cf
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/functionResources.scala
@@ -0,0 +1,61 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.catalog
+
+import org.apache.spark.sql.AnalysisException
+
+/** A trait that represents the type of a resource needed by a function. */
+sealed trait FunctionResourceType
+
+object JarResource extends FunctionResourceType
+
+object FileResource extends FunctionResourceType
+
+// We do not allow users to specify an archive because it is YARN specific.
+// When loading resources, we will throw an exception and ask users to
+// use --archives with spark-submit.
+object ArchiveResource extends FunctionResourceType
+
+object FunctionResourceType {
+ def fromString(resourceType: String): FunctionResourceType = {
+ resourceType.toLowerCase match {
+ case "jar" => JarResource
+ case "file" => FileResource
+ case "archive" => ArchiveResource
+ case other =>
+ throw new AnalysisException(s"Resource Type '$resourceType' is not supported.")
+ }
+ }
+}
+
+case class FunctionResource(resourceType: FunctionResourceType, uri: String)
+
+/**
+ * A simple trait representing a class that can be used to load resources used by
+ * a function. Because only a SQLContext can load resources, we create this trait
+ * to avoid of explicitly passing SQLContext around.
+ */
+trait FunctionResourceLoader {
+ def loadResource(resource: FunctionResource): Unit
+}
+
+object DummyFunctionResourceLoader extends FunctionResourceLoader {
+ override def loadResource(resource: FunctionResource): Unit = {
+ throw new UnsupportedOperationException
+ }
+}
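
Given the objects added in this new file, resource-type strings are parsed case-insensitively. A short usage sketch, assuming the patched org.apache.spark.sql.catalyst.catalog package is on the classpath:

import org.apache.spark.sql.catalyst.catalog._

object FunctionResourceSketch {
  def main(args: Array[String]): Unit = {
    // Parsing is case-insensitive: "JAR", "jar", "Jar" all map to JarResource.
    val jar = FunctionResource(FunctionResourceType.fromString("JAR"), "/tmp/udfs.jar")
    val file = FunctionResource(FunctionResourceType.fromString("file"), "/tmp/lookup.txt")
    println(Seq(jar, file).map(_.uri))              // List(/tmp/udfs.jar, /tmp/lookup.txt)
    // An unknown type is rejected up front:
    // FunctionResourceType.fromString("zip")       // throws AnalysisException
  }
}
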
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala
index 34803133f6..ad989a97e4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala
@@ -39,7 +39,7 @@ abstract class ExternalCatalog {
protected def requireDbExists(db: String): Unit = {
if (!databaseExists(db)) {
- throw new AnalysisException(s"Database $db does not exist")
+ throw new AnalysisException(s"Database '$db' does not exist")
}
}
@@ -91,6 +91,8 @@ abstract class ExternalCatalog {
def getTable(db: String, table: String): CatalogTable
+ def getTableOption(db: String, table: String): Option[CatalogTable]
+
def tableExists(db: String, table: String): Boolean
def listTables(db: String): Seq[String]
@@ -150,17 +152,10 @@ abstract class ExternalCatalog {
def renameFunction(db: String, oldName: String, newName: String): Unit
- /**
- * Alter a function whose name that matches the one specified in `funcDefinition`,
- * assuming the function exists.
- *
- * Note: If the underlying implementation does not support altering a certain field,
- * this becomes a no-op.
- */
- def alterFunction(db: String, funcDefinition: CatalogFunction): Unit
-
def getFunction(db: String, funcName: String): CatalogFunction
+ def functionExists(db: String, funcName: String): Boolean
+
def listFunctions(db: String, pattern: String): Seq[String]
}
@@ -169,10 +164,15 @@ abstract class ExternalCatalog {
/**
* A function defined in the catalog.
*
- * @param name name of the function
+ * @param identifier name of the function
* @param className fully qualified class name, e.g. "org.apache.spark.util.MyFunc"
+ * @param resources resource types and URIs used by the function
*/
-case class CatalogFunction(name: FunctionIdentifier, className: String)
+// TODO: Use FunctionResource instead of (String, String) as the element type of resources.
+case class CatalogFunction(
+ identifier: FunctionIdentifier,
+ className: String,
+ resources: Seq[(String, String)])
/**
@@ -216,26 +216,42 @@ case class CatalogTablePartition(
* future once we have a better understanding of how we want to handle skewed columns.
*/
case class CatalogTable(
- name: TableIdentifier,
+ identifier: TableIdentifier,
tableType: CatalogTableType,
storage: CatalogStorageFormat,
schema: Seq[CatalogColumn],
- partitionColumns: Seq[CatalogColumn] = Seq.empty,
- sortColumns: Seq[CatalogColumn] = Seq.empty,
- numBuckets: Int = 0,
+ partitionColumnNames: Seq[String] = Seq.empty,
+ sortColumnNames: Seq[String] = Seq.empty,
+ bucketColumnNames: Seq[String] = Seq.empty,
+ numBuckets: Int = -1,
createTime: Long = System.currentTimeMillis,
- lastAccessTime: Long = System.currentTimeMillis,
+ lastAccessTime: Long = -1,
properties: Map[String, String] = Map.empty,
viewOriginalText: Option[String] = None,
- viewText: Option[String] = None) {
+ viewText: Option[String] = None,
+ comment: Option[String] = None) {
+
+ // Verify that the provided columns are part of the schema
+ private val colNames = schema.map(_.name).toSet
+ private def requireSubsetOfSchema(cols: Seq[String], colType: String): Unit = {
+ require(cols.toSet.subsetOf(colNames), s"$colType columns (${cols.mkString(", ")}) " +
+ s"must be a subset of schema (${colNames.mkString(", ")}) in table '$identifier'")
+ }
+ requireSubsetOfSchema(partitionColumnNames, "partition")
+ requireSubsetOfSchema(sortColumnNames, "sort")
+ requireSubsetOfSchema(bucketColumnNames, "bucket")
+
+ /** Columns this table is partitioned by. */
+ def partitionColumns: Seq[CatalogColumn] =
+ schema.filter { c => partitionColumnNames.contains(c.name) }
/** Return the database this table was specified to belong to, assuming it exists. */
- def database: String = name.database.getOrElse {
- throw new AnalysisException(s"table $name did not specify database")
+ def database: String = identifier.database.getOrElse {
+ throw new AnalysisException(s"table $identifier did not specify database")
}
/** Return the fully qualified name of this table, assuming the database was specified. */
- def qualifiedName: String = name.unquotedString
+ def qualifiedName: String = identifier.unquotedString
/** Syntactic sugar to update a field in `storage`. */
def withNewStorage(
@@ -290,6 +306,6 @@ case class CatalogRelation(
// TODO: implement this
override def output: Seq[Attribute] = Seq.empty
- require(metadata.name.database == Some(db),
- "provided database does not much the one specified in the table definition")
+ require(metadata.identifier.database == Some(db),
+ "provided database does not match the one specified in the table definition")
}
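
CatalogTable now validates at construction time that the partition, sort and bucket column names are all drawn from the schema. A self-contained sketch of that check (a stand-alone helper, not the Spark class):

object SchemaSubsetSketch {
  // Same idea as the requireSubsetOfSchema added to CatalogTable above.
  def requireSubsetOfSchema(cols: Seq[String], colNames: Set[String], colType: String): Unit = {
    require(cols.toSet.subsetOf(colNames),
      s"$colType columns (${cols.mkString(", ")}) must be a subset of schema (${colNames.mkString(", ")})")
  }

  def main(args: Array[String]): Unit = {
    val schema = Set("id", "name", "dt")
    requireSubsetOfSchema(Seq("dt"), schema, "partition")       // passes silently
    try {
      requireSubsetOfSchema(Seq("country"), schema, "partition")
    } catch {
      case e: IllegalArgumentException => println(e.getMessage) // reports the offending column
    }
  }
}
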
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
index 3540014c3e..1e7296664b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
@@ -21,7 +21,8 @@ import java.sql.{Date, Timestamp}
import scala.language.implicitConversions
-import org.apache.spark.sql.catalyst.analysis.{EliminateSubqueryAliases, UnresolvedAttribute, UnresolvedExtractValue}
+import org.apache.spark.sql.Encoder
+import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.plans.{Inner, JoinType}
@@ -161,6 +162,18 @@ package object dsl {
def lower(e: Expression): Expression = Lower(e)
def sqrt(e: Expression): Expression = Sqrt(e)
def abs(e: Expression): Expression = Abs(e)
+ def star(names: String*): Expression = names match {
+ case Seq() => UnresolvedStar(None)
+ case target => UnresolvedStar(Option(target))
+ }
+
+ def callFunction[T, U](
+ func: T => U,
+ returnType: DataType,
+ argument: Expression): Expression = {
+ val function = Literal.create(func, ObjectType(classOf[T => U]))
+ Invoke(function, "apply", returnType, argument :: Nil)
+ }
implicit class DslSymbol(sym: Symbol) extends ImplicitAttribute { def s: String = sym.name }
// TODO more implicit class for literal?
@@ -231,6 +244,12 @@ package object dsl {
AttributeReference(s, structType, nullable = true)()
def struct(attrs: AttributeReference*): AttributeReference =
struct(StructType.fromAttributes(attrs))
+
+ /** Create a function. */
+ def function(exprs: Expression*): UnresolvedFunction =
+ UnresolvedFunction(s, exprs, isDistinct = false)
+ def distinctFunction(exprs: Expression*): UnresolvedFunction =
+ UnresolvedFunction(s, exprs, isDistinct = true)
}
implicit class DslAttribute(a: AttributeReference) {
@@ -243,11 +262,33 @@ package object dsl {
object expressions extends ExpressionConversions // scalastyle:ignore
object plans { // scalastyle:ignore
+ def table(ref: String): LogicalPlan =
+ UnresolvedRelation(TableIdentifier(ref), None)
+
+ def table(db: String, ref: String): LogicalPlan =
+ UnresolvedRelation(TableIdentifier(ref, Option(db)), None)
+
implicit class DslLogicalPlan(val logicalPlan: LogicalPlan) {
- def select(exprs: NamedExpression*): LogicalPlan = Project(exprs, logicalPlan)
+ def select(exprs: Expression*): LogicalPlan = {
+ val namedExpressions = exprs.map {
+ case e: NamedExpression => e
+ case e => UnresolvedAlias(e)
+ }
+ Project(namedExpressions, logicalPlan)
+ }
def where(condition: Expression): LogicalPlan = Filter(condition, logicalPlan)
+ def filter[T : Encoder](func: T => Boolean): LogicalPlan = {
+ val deserialized = logicalPlan.deserialize[T]
+ val condition = expressions.callFunction(func, BooleanType, deserialized.output.head)
+ Filter(condition, deserialized).serialize[T]
+ }
+
+ def serialize[T : Encoder]: LogicalPlan = CatalystSerde.serialize[T](logicalPlan)
+
+ def deserialize[T : Encoder]: LogicalPlan = CatalystSerde.deserialize[T](logicalPlan)
+
def limit(limitExpr: Expression): LogicalPlan = Limit(limitExpr, logicalPlan)
def join(
@@ -296,6 +337,14 @@ package object dsl {
analysis.UnresolvedRelation(TableIdentifier(tableName)),
Map.empty, logicalPlan, overwrite, false)
+ def as(alias: String): LogicalPlan = logicalPlan match {
+ case UnresolvedRelation(tbl, _) => UnresolvedRelation(tbl, Option(alias))
+ case plan => SubqueryAlias(alias, plan)
+ }
+
+ def distribute(exprs: Expression*): LogicalPlan =
+ RepartitionByExpression(exprs, logicalPlan)
+
def analyze: LogicalPlan =
EliminateSubqueryAliases(analysis.SimpleAnalyzer.execute(logicalPlan))
}
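
The widened select above accepts arbitrary expressions by aliasing anything that is not already named. A minimal stand-alone sketch of that wrapping rule, using placeholder types rather than catalyst's NamedExpression/UnresolvedAlias hierarchy:

sealed trait Expr
final case class Named(name: String, child: Expr) extends Expr
final case class Add(left: String, right: String) extends Expr
final case class Col(name: String) extends Expr

object SelectWrapSketch {
  def toNamed(exprs: Seq[Expr]): Seq[Expr] = exprs.map {
    case n: Named => n                          // already named: keep as-is
    case e        => Named("<unresolved>", e)   // otherwise wrap, like UnresolvedAlias(e)
  }

  def main(args: Array[String]): Unit = {
    println(toNamed(Seq(Named("total", Add("a", "b")), Col("id"))))
  }
}
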
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
index 918233ddcd..56d29cfbe1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
@@ -24,7 +24,7 @@ import scala.reflect.runtime.universe.{typeTag, TypeTag}
import org.apache.spark.sql.{AnalysisException, Encoder}
import org.apache.spark.sql.catalyst.{InternalRow, JavaTypeInference, ScalaReflection}
-import org.apache.spark.sql.catalyst.analysis.{SimpleAnalyzer, UnresolvedAttribute, UnresolvedExtractValue}
+import org.apache.spark.sql.catalyst.analysis.{SimpleAnalyzer, UnresolvedAttribute, UnresolvedDeserializer, UnresolvedExtractValue}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateSafeProjection, GenerateUnsafeProjection}
import org.apache.spark.sql.catalyst.optimizer.SimplifyCasts
@@ -51,8 +51,8 @@ object ExpressionEncoder {
val flat = !classOf[Product].isAssignableFrom(cls)
val inputObject = BoundReference(0, ScalaReflection.dataTypeFor[T], nullable = false)
- val toRowExpression = ScalaReflection.extractorsFor[T](inputObject)
- val fromRowExpression = ScalaReflection.constructorFor[T]
+ val serializer = ScalaReflection.serializerFor[T](inputObject)
+ val deserializer = ScalaReflection.deserializerFor[T]
val schema = ScalaReflection.schemaFor[T] match {
case ScalaReflection.Schema(s: StructType, _) => s
@@ -62,8 +62,8 @@ object ExpressionEncoder {
new ExpressionEncoder[T](
schema,
flat,
- toRowExpression.flatten,
- fromRowExpression,
+ serializer.flatten,
+ deserializer,
ClassTag[T](cls))
}
@@ -72,14 +72,14 @@ object ExpressionEncoder {
val schema = JavaTypeInference.inferDataType(beanClass)._1
assert(schema.isInstanceOf[StructType])
- val toRowExpression = JavaTypeInference.extractorsFor(beanClass)
- val fromRowExpression = JavaTypeInference.constructorFor(beanClass)
+ val serializer = JavaTypeInference.serializerFor(beanClass)
+ val deserializer = JavaTypeInference.deserializerFor(beanClass)
new ExpressionEncoder[T](
schema.asInstanceOf[StructType],
flat = false,
- toRowExpression.flatten,
- fromRowExpression,
+ serializer.flatten,
+ deserializer,
ClassTag[T](beanClass))
}
@@ -103,9 +103,9 @@ object ExpressionEncoder {
val cls = Utils.getContextOrSparkClassLoader.loadClass(s"scala.Tuple${encoders.size}")
- val toRowExpressions = encoders.map {
- case e if e.flat => e.toRowExpressions.head
- case other => CreateStruct(other.toRowExpressions)
+ val serializer = encoders.map {
+ case e if e.flat => e.serializer.head
+ case other => CreateStruct(other.serializer)
}.zipWithIndex.map { case (expr, index) =>
expr.transformUp {
case BoundReference(0, t, _) =>
@@ -116,14 +116,14 @@ object ExpressionEncoder {
}
}
- val fromRowExpressions = encoders.zipWithIndex.map { case (enc, index) =>
+ val childrenDeserializers = encoders.zipWithIndex.map { case (enc, index) =>
if (enc.flat) {
- enc.fromRowExpression.transform {
+ enc.deserializer.transform {
case b: BoundReference => b.copy(ordinal = index)
}
} else {
val input = BoundReference(index, enc.schema, nullable = true)
- enc.fromRowExpression.transformUp {
+ enc.deserializer.transformUp {
case UnresolvedAttribute(nameParts) =>
assert(nameParts.length == 1)
UnresolvedExtractValue(input, Literal(nameParts.head))
@@ -132,14 +132,14 @@ object ExpressionEncoder {
}
}
- val fromRowExpression =
- NewInstance(cls, fromRowExpressions, ObjectType(cls), propagateNull = false)
+ val deserializer =
+ NewInstance(cls, childrenDeserializers, ObjectType(cls), propagateNull = false)
new ExpressionEncoder[Any](
schema,
flat = false,
- toRowExpressions,
- fromRowExpression,
+ serializer,
+ deserializer,
ClassTag(cls))
}
@@ -174,29 +174,29 @@ object ExpressionEncoder {
* A generic encoder for JVM objects.
*
* @param schema The schema after converting `T` to a Spark SQL row.
- * @param toRowExpressions A set of expressions, one for each top-level field that can be used to
- * extract the values from a raw object into an [[InternalRow]].
- * @param fromRowExpression An expression that will construct an object given an [[InternalRow]].
+ * @param serializer A set of expressions, one for each top-level field that can be used to
+ * extract the values from a raw object into an [[InternalRow]].
+ * @param deserializer An expression that will construct an object given an [[InternalRow]].
* @param clsTag A classtag for `T`.
*/
case class ExpressionEncoder[T](
schema: StructType,
flat: Boolean,
- toRowExpressions: Seq[Expression],
- fromRowExpression: Expression,
+ serializer: Seq[Expression],
+ deserializer: Expression,
clsTag: ClassTag[T])
extends Encoder[T] {
- if (flat) require(toRowExpressions.size == 1)
+ if (flat) require(serializer.size == 1)
@transient
- private lazy val extractProjection = GenerateUnsafeProjection.generate(toRowExpressions)
+ private lazy val extractProjection = GenerateUnsafeProjection.generate(serializer)
@transient
private lazy val inputRow = new GenericMutableRow(1)
@transient
- private lazy val constructProjection = GenerateSafeProjection.generate(fromRowExpression :: Nil)
+ private lazy val constructProjection = GenerateSafeProjection.generate(deserializer :: Nil)
/**
* Returns this encoder where it has been bound to its own output (i.e. no remaping of columns
@@ -212,7 +212,7 @@ case class ExpressionEncoder[T](
* Returns a new set (with unique ids) of [[NamedExpression]] that represent the serialized form
* of this object.
*/
- def namedExpressions: Seq[NamedExpression] = schema.map(_.name).zip(toRowExpressions).map {
+ def namedExpressions: Seq[NamedExpression] = schema.map(_.name).zip(serializer).map {
case (_, ne: NamedExpression) => ne.newInstance()
case (name, e) => Alias(e, name)()
}
@@ -228,7 +228,7 @@ case class ExpressionEncoder[T](
} catch {
case e: Exception =>
throw new RuntimeException(
- s"Error while encoding: $e\n${toRowExpressions.map(_.treeString).mkString("\n")}", e)
+ s"Error while encoding: $e\n${serializer.map(_.treeString).mkString("\n")}", e)
}
/**
@@ -240,7 +240,7 @@ case class ExpressionEncoder[T](
constructProjection(row).get(0, ObjectType(clsTag.runtimeClass)).asInstanceOf[T]
} catch {
case e: Exception =>
- throw new RuntimeException(s"Error while decoding: $e\n${fromRowExpression.treeString}", e)
+ throw new RuntimeException(s"Error while decoding: $e\n${deserializer.treeString}", e)
}
/**
@@ -249,7 +249,7 @@ case class ExpressionEncoder[T](
* has not been done already in places where we plan to do later composition of encoders.
*/
def assertUnresolved(): Unit = {
- (fromRowExpression +: toRowExpressions).foreach(_.foreach {
+ (deserializer +: serializer).foreach(_.foreach {
case a: AttributeReference if a.name != "loopVar" =>
sys.error(s"Unresolved encoder expected, but $a was found.")
case _ =>
@@ -257,7 +257,7 @@ case class ExpressionEncoder[T](
}
/**
- * Validates `fromRowExpression` to make sure it can be resolved by given schema, and produce
+ * Validates `deserializer` to make sure it can be resolved by given schema, and produce
* friendly error messages to explain why it fails to resolve if there is something wrong.
*/
def validate(schema: Seq[Attribute]): Unit = {
@@ -271,7 +271,7 @@ case class ExpressionEncoder[T](
// If this is a tuple encoder or tupled encoder, which means its leaf nodes are all
// `BoundReference`, make sure their ordinals are all valid.
var maxOrdinal = -1
- fromRowExpression.foreach {
+ deserializer.foreach {
case b: BoundReference => if (b.ordinal > maxOrdinal) maxOrdinal = b.ordinal
case _ =>
}
@@ -285,7 +285,7 @@ case class ExpressionEncoder[T](
// we unbound it by the given `schema` and propagate the actual type to `GetStructField`, after
// we resolve the `fromRowExpression`.
val resolved = SimpleAnalyzer.resolveExpression(
- fromRowExpression,
+ deserializer,
LocalRelation(schema),
throws = true)
@@ -312,42 +312,39 @@ case class ExpressionEncoder[T](
}
/**
- * Returns a new copy of this encoder, where the expressions used by `fromRow` are resolved to the
- * given schema.
+ * Returns a new copy of this encoder, where the `deserializer` is resolved to the given schema.
*/
def resolve(
schema: Seq[Attribute],
outerScopes: ConcurrentMap[String, AnyRef]): ExpressionEncoder[T] = {
- val deserializer = SimpleAnalyzer.ResolveReferences.resolveDeserializer(
- fromRowExpression, schema)
-
// Make a fake plan to wrap the deserializer, so that we can go though the whole analyzer, check
// analysis, go through optimizer, etc.
- val plan = Project(Alias(deserializer, "")() :: Nil, LocalRelation(schema))
+ val plan = Project(
+ Alias(UnresolvedDeserializer(deserializer, schema), "")() :: Nil,
+ LocalRelation(schema))
val analyzedPlan = SimpleAnalyzer.execute(plan)
SimpleAnalyzer.checkAnalysis(analyzedPlan)
- copy(fromRowExpression = SimplifyCasts(analyzedPlan).expressions.head.children.head)
+ copy(deserializer = SimplifyCasts(analyzedPlan).expressions.head.children.head)
}
/**
- * Returns a copy of this encoder where the expressions used to construct an object from an input
- * row have been bound to the ordinals of the given schema. Note that you need to first call
- * resolve before bind.
+ * Returns a copy of this encoder where the `deserializer` has been bound to the
+ * ordinals of the given schema. Note that you need to first call resolve before bind.
*/
def bind(schema: Seq[Attribute]): ExpressionEncoder[T] = {
- copy(fromRowExpression = BindReferences.bindReference(fromRowExpression, schema))
+ copy(deserializer = BindReferences.bindReference(deserializer, schema))
}
/**
* Returns a new encoder with input columns shifted by `delta` ordinals
*/
def shift(delta: Int): ExpressionEncoder[T] = {
- copy(fromRowExpression = fromRowExpression transform {
+ copy(deserializer = deserializer transform {
case r: BoundReference => r.copy(ordinal = r.ordinal + delta)
})
}
- protected val attrs = toRowExpressions.flatMap(_.collect {
+ protected val attrs = serializer.flatMap(_.collect {
case _: UnresolvedAttribute => ""
case a: Attribute => s"#${a.exprId}"
case b: BoundReference => s"[${b.ordinal}]"
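
The rename in this file is terminological: the "serializer" expressions turn an object into a row, and the single "deserializer" expression rebuilds the object. A toy, self-contained encoder that mirrors that shape (rows are plain Seq[Any] stand-ins, not Spark's ExpressionEncoder API):

final case class ToyEncoder[T](
    serializer: T => Seq[Any],        // one entry per top-level field
    deserializer: Seq[Any] => T)      // rebuilds the object from a row

object ToyEncoderExample {
  final case class Person(name: String, age: Int)

  val personEncoder: ToyEncoder[Person] = ToyEncoder(
    p => Seq(p.name, p.age),
    row => Person(row(0).asInstanceOf[String], row(1).asInstanceOf[Int]))

  def main(args: Array[String]): Unit = {
    val row = personEncoder.serializer(Person("Ada", 36))
    println(row)                                  // List(Ada, 36)
    println(personEncoder.deserializer(row))      // Person(Ada,36)
  }
}
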
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
index 30f56d8c2f..a8397aa5e5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
@@ -36,23 +36,23 @@ object RowEncoder {
val cls = classOf[Row]
val inputObject = BoundReference(0, ObjectType(cls), nullable = true)
// We use an If expression to wrap extractorsFor result of StructType
- val extractExpressions = extractorsFor(inputObject, schema).asInstanceOf[If].falseValue
- val constructExpression = constructorFor(schema)
+ val serializer = serializerFor(inputObject, schema).asInstanceOf[If].falseValue
+ val deserializer = deserializerFor(schema)
new ExpressionEncoder[Row](
schema,
flat = false,
- extractExpressions.asInstanceOf[CreateStruct].children,
- constructExpression,
+ serializer.asInstanceOf[CreateStruct].children,
+ deserializer,
ClassTag(cls))
}
- private def extractorsFor(
+ private def serializerFor(
inputObject: Expression,
inputType: DataType): Expression = inputType match {
case NullType | BooleanType | ByteType | ShortType | IntegerType | LongType |
FloatType | DoubleType | BinaryType | CalendarIntervalType => inputObject
- case p: PythonUserDefinedType => extractorsFor(inputObject, p.sqlType)
+ case p: PythonUserDefinedType => serializerFor(inputObject, p.sqlType)
case udt: UserDefinedType[_] =>
val obj = NewInstance(
@@ -95,7 +95,7 @@ object RowEncoder {
classOf[GenericArrayData],
inputObject :: Nil,
dataType = t)
- case _ => MapObjects(extractorsFor(_, et), inputObject, externalDataTypeForInput(et))
+ case _ => MapObjects(serializerFor(_, et), inputObject, externalDataTypeForInput(et))
}
case t @ MapType(kt, vt, valueNullable) =>
@@ -104,14 +104,14 @@ object RowEncoder {
Invoke(inputObject, "keysIterator", ObjectType(classOf[scala.collection.Iterator[_]])),
"toSeq",
ObjectType(classOf[scala.collection.Seq[_]]))
- val convertedKeys = extractorsFor(keys, ArrayType(kt, false))
+ val convertedKeys = serializerFor(keys, ArrayType(kt, false))
val values =
Invoke(
Invoke(inputObject, "valuesIterator", ObjectType(classOf[scala.collection.Iterator[_]])),
"toSeq",
ObjectType(classOf[scala.collection.Seq[_]]))
- val convertedValues = extractorsFor(values, ArrayType(vt, valueNullable))
+ val convertedValues = serializerFor(values, ArrayType(vt, valueNullable))
NewInstance(
classOf[ArrayBasedMapData],
@@ -128,7 +128,7 @@ object RowEncoder {
If(
Invoke(inputObject, "isNullAt", BooleanType, Literal(i) :: Nil),
Literal.create(null, f.dataType),
- extractorsFor(
+ serializerFor(
Invoke(inputObject, method, externalDataTypeForInput(f.dataType), Literal(i) :: Nil),
f.dataType))
}
@@ -166,7 +166,7 @@ object RowEncoder {
case _: NullType => ObjectType(classOf[java.lang.Object])
}
- private def constructorFor(schema: StructType): Expression = {
+ private def deserializerFor(schema: StructType): Expression = {
val fields = schema.zipWithIndex.map { case (f, i) =>
val dt = f.dataType match {
case p: PythonUserDefinedType => p.sqlType
@@ -176,13 +176,13 @@ object RowEncoder {
If(
IsNull(field),
Literal.create(null, externalDataTypeFor(dt)),
- constructorFor(field)
+ deserializerFor(field)
)
}
CreateExternalRow(fields, schema)
}
- private def constructorFor(input: Expression): Expression = input.dataType match {
+ private def deserializerFor(input: Expression): Expression = input.dataType match {
case NullType | BooleanType | ByteType | ShortType | IntegerType | LongType |
FloatType | DoubleType | BinaryType | CalendarIntervalType => input
@@ -216,7 +216,7 @@ object RowEncoder {
case ArrayType(et, nullable) =>
val arrayData =
Invoke(
- MapObjects(constructorFor(_), input, et),
+ MapObjects(deserializerFor(_), input, et),
"array",
ObjectType(classOf[Array[_]]))
StaticInvoke(
@@ -227,10 +227,10 @@ object RowEncoder {
case MapType(kt, vt, valueNullable) =>
val keyArrayType = ArrayType(kt, false)
- val keyData = constructorFor(Invoke(input, "keyArray", keyArrayType))
+ val keyData = deserializerFor(Invoke(input, "keyArray", keyArrayType))
val valueArrayType = ArrayType(vt, valueNullable)
- val valueData = constructorFor(Invoke(input, "valueArray", valueArrayType))
+ val valueData = deserializerFor(Invoke(input, "valueArray", valueArrayType))
StaticInvoke(
ArrayBasedMapData.getClass,
@@ -243,7 +243,7 @@ object RowEncoder {
If(
Invoke(input, "isNullAt", BooleanType, Literal(i) :: Nil),
Literal.create(null, externalDataTypeFor(f.dataType)),
- constructorFor(GetStructField(input, i)))
+ deserializerFor(GetStructField(input, i)))
}
If(IsNull(input),
Literal.create(null, externalDataTypeFor(input.dataType)),
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/errors/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/errors/package.scala
index 0d44d1dd96..0420b4b538 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/errors/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/errors/package.scala
@@ -25,15 +25,18 @@ import org.apache.spark.sql.catalyst.trees.TreeNode
package object errors {
class TreeNodeException[TreeType <: TreeNode[_]](
- tree: TreeType, msg: String, cause: Throwable)
+ @transient val tree: TreeType,
+ msg: String,
+ cause: Throwable)
extends Exception(msg, cause) {
+ val treeString = tree.toString
+
// Yes, this is the same as a default parameter, but... those don't seem to work with SBT
// external project dependencies for some reason.
def this(tree: TreeType, msg: String) = this(tree, msg, null)
override def getMessage: String = {
- val treeString = tree.toString
s"${super.getMessage}, tree:${if (treeString contains "\n") "\n" else " "}$tree"
}
}
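
The change above marks the tree @transient and captures its string form eagerly, so the message stays useful even when the tree itself is not shipped. A small stand-alone sketch of the same pattern, with hypothetical classes rather than catalyst's TreeNodeException:

import java.io._

// A heavy, non-serializable payload standing in for the TreeNode.
final class Payload { override def toString: String = "Project [a, b]" }

// Capture the description eagerly; the transient payload is not serialized.
class DescribedException(@transient val payload: Payload, msg: String)
    extends Exception(msg) {
  val payloadString: String = payload.toString
  override def getMessage: String = s"${super.getMessage}, tree: $payloadString"
}

object TransientSketch {
  def main(args: Array[String]): Unit = {
    val ex = new DescribedException(new Payload, "resolution failed")
    // Round-trip through Java serialization: the payload is dropped, the message survives.
    val bytes = new ByteArrayOutputStream()
    val oos = new ObjectOutputStream(bytes)
    oos.writeObject(ex)
    oos.close()
    val back = new ObjectInputStream(new ByteArrayInputStream(bytes.toByteArray))
      .readObject().asInstanceOf[DescribedException]
    println(back.getMessage)   // resolution failed, tree: Project [a, b]
  }
}
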
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index a965cc8d53..0f8876a9e6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -112,7 +112,7 @@ object Cast {
}
/** Cast the child expression to the target data type. */
-case class Cast(child: Expression, dataType: DataType) extends UnaryExpression {
+case class Cast(child: Expression, dataType: DataType) extends UnaryExpression with NullIntolerant {
override def toString: String = s"cast($child as ${dataType.simpleString})"
@@ -898,7 +898,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression {
val result = ctx.freshName("result")
val tmpRow = ctx.freshName("tmpRow")
- val fieldsEvalCode = fieldsCasts.zipWithIndex.map { case (cast, i) => {
+ val fieldsEvalCode = fieldsCasts.zipWithIndex.map { case (cast, i) =>
val fromFieldPrim = ctx.freshName("ffp")
val fromFieldNull = ctx.freshName("ffn")
val toFieldPrim = ctx.freshName("tfp")
@@ -920,7 +920,6 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression {
}
}
"""
- }
}.mkString("\n")
(c, evPrim, evNull) =>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala
index affd1bdb32..8d8cc152ff 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala
@@ -97,11 +97,11 @@ class EquivalentExpressions {
def debugString(all: Boolean = false): String = {
val sb: mutable.StringBuilder = new StringBuilder()
sb.append("Equivalent expressions:\n")
- equivalenceMap.foreach { case (k, v) => {
+ equivalenceMap.foreach { case (k, v) =>
if (all || v.length > 1) {
sb.append(" " + v.mkString(", ")).append("\n")
}
- }}
+ }
sb.toString()
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
index 5f8899d599..a24a5db8d4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
@@ -153,8 +153,8 @@ abstract class Expression extends TreeNode[Expression] {
* evaluate to the same result.
*/
lazy val canonicalized: Expression = {
- val canonicalizedChildred = children.map(_.canonicalized)
- Canonicalize.execute(withNewChildren(canonicalizedChildred))
+ val canonicalizedChildren = children.map(_.canonicalized)
+ Canonicalize.execute(withNewChildren(canonicalizedChildren))
}
/**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala
index dbd0acf06c..2ed6fc0d38 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala
@@ -17,14 +17,14 @@
package org.apache.spark.sql.catalyst.expressions
-import org.apache.spark.rdd.SqlNewHadoopRDDState
+import org.apache.spark.rdd.InputFileNameHolder
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.types.{DataType, StringType}
import org.apache.spark.unsafe.types.UTF8String
/**
- * Expression that returns the name of the current file being read in using [[SqlNewHadoopRDD]]
+ * Expression that returns the name of the current file being read.
*/
@ExpressionDescription(
usage = "_FUNC_() - Returns the name of the current file being read if available",
@@ -40,12 +40,12 @@ case class InputFileName() extends LeafExpression with Nondeterministic {
override protected def initInternal(): Unit = {}
override protected def evalInternal(input: InternalRow): UTF8String = {
- SqlNewHadoopRDDState.getInputFileName()
+ InputFileNameHolder.getInputFileName()
}
override def genCode(ctx: CodegenContext, ev: ExprCode): String = {
ev.isNull = "false"
s"final ${ctx.javaType(dataType)} ${ev.value} = " +
- "org.apache.spark.rdd.SqlNewHadoopRDDState.getInputFileName();"
+ "org.apache.spark.rdd.InputFileNameHolder.getInputFileName();"
}
}
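The InputFileName hunks above only track the rename from SqlNewHadoopRDDState to InputFileNameHolder; the expression still reads a per-thread value that the file-reading task publishes. A hypothetical sketch of that holder pattern (the real object lives in org.apache.spark.rdd; this snippet assumes the Spark unsafe module is on the classpath for UTF8String):

import org.apache.spark.unsafe.types.UTF8String

object InputFileNameHolderSketch {
  // One slot per task thread; empty string when no file is being read.
  private val current = new ThreadLocal[UTF8String] {
    override def initialValue(): UTF8String = UTF8String.fromString("")
  }
  def setInputFileName(name: String): Unit = current.set(UTF8String.fromString(name))
  def getInputFileName(): UTF8String = current.get()
  def unsetInputFileName(): Unit = current.remove()
}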
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
index 053e612f3e..354311c5e7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
@@ -136,9 +136,9 @@ object UnsafeProjection {
}
/**
- * Same as other create()'s but allowing enabling/disabling subexpression elimination.
- * TODO: refactor the plumbing and clean this up.
- */
+ * Same as other create()'s but allowing enabling/disabling subexpression elimination.
+ * TODO: refactor the plumbing and clean this up.
+ */
def create(
exprs: Seq[Expression],
inputSchema: Seq[Attribute],
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala
index 4615c55d67..61ca7272df 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala
@@ -62,7 +62,7 @@ import org.apache.spark.sql.types._
abstract class MutableValue extends Serializable {
var isNull: Boolean = true
def boxed: Any
- def update(v: Any)
+ def update(v: Any): Unit
def copy(): MutableValue
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala
new file mode 100644
index 0000000000..daf3de95dd
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala
@@ -0,0 +1,168 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions
+
+import org.apache.commons.lang.StringUtils
+
+import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.TypeCheckFailure
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
+import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.CalendarInterval
+
+case class TimeWindow(
+ timeColumn: Expression,
+ windowDuration: Long,
+ slideDuration: Long,
+ startTime: Long) extends UnaryExpression
+ with ImplicitCastInputTypes
+ with Unevaluable
+ with NonSQLExpression {
+
+ //////////////////////////
+ // SQL Constructors
+ //////////////////////////
+
+ def this(
+ timeColumn: Expression,
+ windowDuration: Expression,
+ slideDuration: Expression,
+ startTime: Expression) = {
+ this(timeColumn, TimeWindow.parseExpression(windowDuration),
+ TimeWindow.parseExpression(slideDuration), TimeWindow.parseExpression(startTime))
+ }
+
+ def this(timeColumn: Expression, windowDuration: Expression, slideDuration: Expression) = {
+ this(timeColumn, TimeWindow.parseExpression(windowDuration),
+ TimeWindow.parseExpression(slideDuration), 0)
+ }
+
+ def this(timeColumn: Expression, windowDuration: Expression) = {
+ this(timeColumn, windowDuration, windowDuration)
+ }
+
+ override def child: Expression = timeColumn
+ override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType)
+ override def dataType: DataType = new StructType()
+ .add(StructField("start", TimestampType))
+ .add(StructField("end", TimestampType))
+
+ // This expression is replaced in the analyzer.
+ override lazy val resolved = false
+
+ /**
+ * Validate the inputs for the window duration, slide duration, and start time in addition to
+ * the input data type.
+ */
+ override def checkInputDataTypes(): TypeCheckResult = {
+ val dataTypeCheck = super.checkInputDataTypes()
+ if (dataTypeCheck.isSuccess) {
+ if (windowDuration <= 0) {
+ return TypeCheckFailure(s"The window duration ($windowDuration) must be greater than 0.")
+ }
+ if (slideDuration <= 0) {
+ return TypeCheckFailure(s"The slide duration ($slideDuration) must be greater than 0.")
+ }
+ if (startTime < 0) {
+ return TypeCheckFailure(s"The start time ($startTime) must be greater than or equal to 0.")
+ }
+ if (slideDuration > windowDuration) {
+ return TypeCheckFailure(s"The slide duration ($slideDuration) must be less than or equal" +
+ s" to the windowDuration ($windowDuration).")
+ }
+ if (startTime >= slideDuration) {
+ return TypeCheckFailure(s"The start time ($startTime) must be less than the " +
+ s"slideDuration ($slideDuration).")
+ }
+ }
+ dataTypeCheck
+ }
+}
+
+object TimeWindow {
+ /**
+ * Parses the interval string for a valid time duration. CalendarInterval expects interval
+ * strings to start with the string `interval`. For usability, we prepend `interval` to the string
+ * if the user omitted it.
+ *
+ * @param interval The interval string
+ * @return The interval duration in microseconds. Spark SQL's TimestampType has microsecond
+ * precision.
+ */
+ private def getIntervalInMicroSeconds(interval: String): Long = {
+ if (StringUtils.isBlank(interval)) {
+ throw new IllegalArgumentException(
+ "The window duration, slide duration and start time cannot be null or blank.")
+ }
+ val intervalString = if (interval.startsWith("interval")) {
+ interval
+ } else {
+ "interval " + interval
+ }
+ val cal = CalendarInterval.fromString(intervalString)
+ if (cal == null) {
+ throw new IllegalArgumentException(
+ s"The provided interval ($interval) did not correspond to a valid interval string.")
+ }
+ if (cal.months > 0) {
+ throw new IllegalArgumentException(
+ s"Intervals greater than a month is not supported ($interval).")
+ }
+ cal.microseconds
+ }
+
+ /**
+ * Parses the duration expression to generate the long value for the original constructor so
+ * that we can use `window` in SQL.
+ */
+ private def parseExpression(expr: Expression): Long = expr match {
+ case NonNullLiteral(s, StringType) => getIntervalInMicroSeconds(s.toString)
+ case IntegerLiteral(i) => i.toLong
+ case NonNullLiteral(l, LongType) => l.toString.toLong
+ case _ => throw new AnalysisException("The duration and time inputs to window must be " +
+ "an integer, long or string literal.")
+ }
+
+ def apply(
+ timeColumn: Expression,
+ windowDuration: String,
+ slideDuration: String,
+ startTime: String): TimeWindow = {
+ TimeWindow(timeColumn,
+ getIntervalInMicroSeconds(windowDuration),
+ getIntervalInMicroSeconds(slideDuration),
+ getIntervalInMicroSeconds(startTime))
+ }
+}
+
+/**
+ * Expression used internally to convert the TimestampType to Long without losing
+ * precision, i.e. in microseconds. Used in time windowing.
+ */
+case class PreciseTimestamp(child: Expression) extends UnaryExpression with ExpectsInputTypes {
+ override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType)
+ override def dataType: DataType = LongType
+ override def genCode(ctx: CodegenContext, ev: ExprCode): String = {
+ val eval = child.gen(ctx)
+ eval.code +
+ s"""boolean ${ev.isNull} = ${eval.isNull};
+ |${ctx.javaType(dataType)} ${ev.value} = ${eval.value};
+ """.stripMargin
+ }
+}
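TimeWindow.getIntervalInMicroSeconds above prepends the literal "interval" when the user omits it and then relies on CalendarInterval for the actual parsing. A simplified, self-contained sketch of that contract (it only understands "<n> seconds" so it runs without Spark on the classpath; the real code handles every CalendarInterval unit except months):

object IntervalSketch {
  def intervalInMicros(interval: String): Long = {
    require(interval != null && interval.trim.nonEmpty,
      "The window duration, slide duration and start time cannot be null or blank.")
    // CalendarInterval expects strings starting with "interval", so prepend it if missing.
    val normalized = if (interval.startsWith("interval")) interval else "interval " + interval
    val Seconds = """interval (\d+) seconds?""".r
    normalized match {
      case Seconds(n) => n.toLong * 1000L * 1000L  // TimestampType has microsecond precision
      case _ => throw new IllegalArgumentException(
        s"The provided interval ($interval) did not correspond to a valid interval string.")
    }
  }

  def main(args: Array[String]): Unit =
    assert(intervalInMicros("10 seconds") == 10000000L)
}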
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala
index 94ac4bf09b..ff70774847 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala
@@ -23,6 +23,8 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.types._
+@ExpressionDescription(
+ usage = "_FUNC_(x) - Returns the mean calculated from values of a group.")
case class Average(child: Expression) extends DeclarativeAggregate {
override def prettyName: String = "avg"
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala
index 9d2db45144..17a7c6dce8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala
@@ -130,6 +130,10 @@ abstract class CentralMomentAgg(child: Expression) extends DeclarativeAggregate
}
// Compute the population standard deviation of a column
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+ usage = "_FUNC_(x) - Returns the population standard deviation calculated from values of a group.")
+// scalastyle:on line.size.limit
case class StddevPop(child: Expression) extends CentralMomentAgg(child) {
override protected def momentOrder = 2
@@ -143,6 +147,8 @@ case class StddevPop(child: Expression) extends CentralMomentAgg(child) {
}
// Compute the sample standard deviation of a column
+@ExpressionDescription(
+ usage = "_FUNC_(x) - Returns the sample standard deviation calculated from values of a group.")
case class StddevSamp(child: Expression) extends CentralMomentAgg(child) {
override protected def momentOrder = 2
@@ -157,6 +163,8 @@ case class StddevSamp(child: Expression) extends CentralMomentAgg(child) {
}
// Compute the population variance of a column
+@ExpressionDescription(
+ usage = "_FUNC_(x) - Returns the population variance calculated from values of a group.")
case class VariancePop(child: Expression) extends CentralMomentAgg(child) {
override protected def momentOrder = 2
@@ -170,6 +178,8 @@ case class VariancePop(child: Expression) extends CentralMomentAgg(child) {
}
// Compute the sample variance of a column
+@ExpressionDescription(
+ usage = "_FUNC_(x) - Returns the sample variance calculated from values of a group.")
case class VarianceSamp(child: Expression) extends CentralMomentAgg(child) {
override protected def momentOrder = 2
@@ -183,6 +193,8 @@ case class VarianceSamp(child: Expression) extends CentralMomentAgg(child) {
override def prettyName: String = "var_samp"
}
+@ExpressionDescription(
+ usage = "_FUNC_(x) - Returns the Skewness value calculated from values of a group.")
case class Skewness(child: Expression) extends CentralMomentAgg(child) {
override def prettyName: String = "skewness"
@@ -196,6 +208,8 @@ case class Skewness(child: Expression) extends CentralMomentAgg(child) {
}
}
+@ExpressionDescription(
+ usage = "_FUNC_(x) - Returns the Kurtosis value calculated from values of a group.")
case class Kurtosis(child: Expression) extends CentralMomentAgg(child) {
override protected def momentOrder = 4
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala
index e6b8214ef2..e29265e2f4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala
@@ -28,6 +28,8 @@ import org.apache.spark.sql.types._
* Definition of Pearson correlation can be found at
* http://en.wikipedia.org/wiki/Pearson_product-moment_correlation_coefficient
*/
+@ExpressionDescription(
+ usage = "_FUNC_(x,y) - Returns Pearson coefficient of correlation between a set of number pairs.")
case class Corr(x: Expression, y: Expression) extends DeclarativeAggregate {
override def children: Seq[Expression] = Seq(x, y)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala
index 663c69e799..17ae012af7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala
@@ -21,6 +21,12 @@ import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types._
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+ usage = """_FUNC_(*) - Returns the total number of retrieved rows, including rows containing NULL values.
+ _FUNC_(expr) - Returns the number of rows for which the supplied expression is non-NULL.
+ _FUNC_(DISTINCT expr[, expr...]) - Returns the number of rows for which the supplied expression(s) are unique and non-NULL.""")
+// scalastyle:on line.size.limit
case class Count(children: Seq[Expression]) extends DeclarativeAggregate {
override def nullable: Boolean = false
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala
index c175a8c4c7..d80afbebf7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala
@@ -76,6 +76,8 @@ abstract class Covariance(x: Expression, y: Expression) extends DeclarativeAggre
}
}
+@ExpressionDescription(
+ usage = "_FUNC_(x,y) - Returns the population covariance of a set of number pairs.")
case class CovPopulation(left: Expression, right: Expression) extends Covariance(left, right) {
override val evaluateExpression: Expression = {
If(n === Literal(0.0), Literal.create(null, DoubleType),
@@ -85,6 +87,8 @@ case class CovPopulation(left: Expression, right: Expression) extends Covariance
}
+@ExpressionDescription(
+ usage = "_FUNC_(x,y) - Returns the sample covariance of a set of number pairs.")
case class CovSample(left: Expression, right: Expression) extends Covariance(left, right) {
override val evaluateExpression: Expression = {
If(n === Literal(0.0), Literal.create(null, DoubleType),
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala
index 35f57426fe..b8ab0364dd 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala
@@ -28,6 +28,11 @@ import org.apache.spark.sql.types._
* is used) its result will not be deterministic (unless the input table is sorted and has
* a single partition, and we use a single reducer to do the aggregation.).
*/
+@ExpressionDescription(
+ usage = """_FUNC_(expr) - Returns the first value of `child` for a group of rows.
+ _FUNC_(expr,isIgnoreNull=false) - Returns the first value of `child` for a group of rows.
+ If isIgnoreNull is true, returns only non-null values.
+ """)
case class First(child: Expression, ignoreNullsExpr: Expression) extends DeclarativeAggregate {
def this(child: Expression) = this(child, Literal.create(false, BooleanType))
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlus.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlus.scala
index b6bd56cff6..1d218da6db 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlus.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlus.scala
@@ -20,8 +20,6 @@ package org.apache.spark.sql.catalyst.expressions.aggregate
import java.lang.{Long => JLong}
import java.util
-import com.clearspring.analytics.hash.MurmurHash
-
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
@@ -48,6 +46,11 @@ import org.apache.spark.sql.types._
* @param relativeSD the maximum estimation error allowed.
*/
// scalastyle:on
+@ExpressionDescription(
+ usage = """_FUNC_(expr) - Returns the estimated cardinality by HyperLogLog++.
+ _FUNC_(expr, relativeSD=0.05) - Returns the estimated cardinality by HyperLogLog++
+ with relativeSD, the maximum estimation error allowed.
+ """)
case class HyperLogLogPlusPlus(
child: Expression,
relativeSD: Double = 0.05,
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala
index be7e12d7a2..b05d74b49b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala
@@ -28,6 +28,8 @@ import org.apache.spark.sql.types._
* is used) its result will not be deterministic (unless the input table is sorted and has
* a single partition, and we use a single reducer to do the aggregation.).
*/
+@ExpressionDescription(
+ usage = "_FUNC_(expr,isIgnoreNull) - Returns the last value of `child` for a group of rows.")
case class Last(child: Expression, ignoreNullsExpr: Expression) extends DeclarativeAggregate {
def this(child: Expression) = this(child, Literal.create(false, BooleanType))
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Max.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Max.scala
index 906003188d..c534fe495f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Max.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Max.scala
@@ -22,6 +22,8 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.types._
+@ExpressionDescription(
+ usage = "_FUNC_(expr) - Returns the maximum value of expr.")
case class Max(child: Expression) extends DeclarativeAggregate {
override def children: Seq[Expression] = child :: Nil
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Min.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Min.scala
index 39f7afbd08..35289b4681 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Min.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Min.scala
@@ -22,7 +22,8 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.types._
-
+@ExpressionDescription(
+ usage = "_FUNC_(expr) - Returns the minimum value of expr.")
case class Min(child: Expression) extends DeclarativeAggregate {
override def children: Seq[Expression] = child :: Nil
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala
index 08a67ea3df..ad217f25b5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala
@@ -22,6 +22,8 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.types._
+@ExpressionDescription(
+ usage = "_FUNC_(x) - Returns the sum calculated from values of a group.")
case class Sum(child: Expression) extends DeclarativeAggregate {
override def children: Seq[Expression] = child :: Nil
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
index ff3064ac66..d31ccf9985 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
@@ -18,6 +18,7 @@
package org.apache.spark.sql.catalyst.expressions.aggregate
import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
import org.apache.spark.sql.types._
@@ -66,6 +67,19 @@ private[sql] case object NoOp extends Expression with Unevaluable {
override def children: Seq[Expression] = Nil
}
+object AggregateExpression {
+ def apply(
+ aggregateFunction: AggregateFunction,
+ mode: AggregateMode,
+ isDistinct: Boolean): AggregateExpression = {
+ AggregateExpression(
+ aggregateFunction,
+ mode,
+ isDistinct,
+ NamedExpression.newExprId)
+ }
+}
+
/**
* A container for an [[AggregateFunction]] with its [[AggregateMode]] and a field
* (`isDistinct`) indicating if DISTINCT keyword is specified for this function.
@@ -73,10 +87,31 @@ private[sql] case object NoOp extends Expression with Unevaluable {
private[sql] case class AggregateExpression(
aggregateFunction: AggregateFunction,
mode: AggregateMode,
- isDistinct: Boolean)
+ isDistinct: Boolean,
+ resultId: ExprId)
extends Expression
with Unevaluable {
+ lazy val resultAttribute: Attribute = if (aggregateFunction.resolved) {
+ AttributeReference(
+ aggregateFunction.toString,
+ aggregateFunction.dataType,
+ aggregateFunction.nullable)(exprId = resultId)
+ } else {
+ // This is a bit of a hack. Really we should not be constructing this container and reasoning
+ // about datatypes / aggregation mode until after we have finished analysis and made it to
+ // planning.
+ UnresolvedAttribute(aggregateFunction.toString)
+ }
+
+ // We compute the same thing regardless of our final result.
+ override lazy val canonicalized: Expression =
+ AggregateExpression(
+ aggregateFunction.canonicalized.asInstanceOf[AggregateFunction],
+ mode,
+ isDistinct,
+ ExprId(0))
+
override def children: Seq[Expression] = aggregateFunction :: Nil
override def dataType: DataType = aggregateFunction.dataType
override def foldable: Boolean = false
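The new canonicalized override above replaces the per-instance resultId with ExprId(0), so two AggregateExpressions that compute the same aggregate compare equal even though each carries a freshly allocated id. A non-Catalyst analogue of that idea (illustrative names only):

// A node carrying a fresh id plus a canonical form that zeroes the id,
// so semantic comparison ignores it.
final case class Node(op: String, id: Long) {
  lazy val canonicalized: Node = copy(id = 0L)
}

object CanonicalizeDemo {
  def main(args: Array[String]): Unit = {
    val a = Node("sum(x)", id = 17L)  // ids differ between plan copies
    val b = Node("sum(x)", id = 42L)
    assert(a != b)                              // raw equality is id-sensitive
    assert(a.canonicalized == b.canonicalized)  // canonical forms compare equal
  }
}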
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
index ed812e0679..f3d42fc0b2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
@@ -23,8 +23,10 @@ import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.CalendarInterval
-
-case class UnaryMinus(child: Expression) extends UnaryExpression with ExpectsInputTypes {
+@ExpressionDescription(
+ usage = "_FUNC_(a) - Returns -a.")
+case class UnaryMinus(child: Expression) extends UnaryExpression
+ with ExpectsInputTypes with NullIntolerant {
override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection.NumericAndInterval)
@@ -58,7 +60,10 @@ case class UnaryMinus(child: Expression) extends UnaryExpression with ExpectsInp
override def sql: String = s"(-${child.sql})"
}
-case class UnaryPositive(child: Expression) extends UnaryExpression with ExpectsInputTypes {
+@ExpressionDescription(
+ usage = "_FUNC_(a) - Returns a.")
+case class UnaryPositive(child: Expression)
+ extends UnaryExpression with ExpectsInputTypes with NullIntolerant {
override def prettyName: String = "positive"
override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection.NumericAndInterval)
@@ -77,9 +82,10 @@ case class UnaryPositive(child: Expression) extends UnaryExpression with Expects
* A function that get the absolute value of the numeric value.
*/
@ExpressionDescription(
- usage = "_FUNC_(expr) - Returns the absolute value of the numeric value",
- extended = "> SELECT _FUNC_('-1');\n1")
-case class Abs(child: Expression) extends UnaryExpression with ExpectsInputTypes {
+ usage = "_FUNC_(expr) - Returns the absolute value of the numeric value.",
+ extended = "> SELECT _FUNC_('-1');\n 1")
+case class Abs(child: Expression)
+ extends UnaryExpression with ExpectsInputTypes with NullIntolerant {
override def inputTypes: Seq[AbstractDataType] = Seq(NumericType)
@@ -123,7 +129,9 @@ private[sql] object BinaryArithmetic {
def unapply(e: BinaryArithmetic): Option[(Expression, Expression)] = Some((e.left, e.right))
}
-case class Add(left: Expression, right: Expression) extends BinaryArithmetic {
+@ExpressionDescription(
+ usage = "a _FUNC_ b - Returns a+b.")
+case class Add(left: Expression, right: Expression) extends BinaryArithmetic with NullIntolerant {
override def inputType: AbstractDataType = TypeCollection.NumericAndInterval
@@ -152,7 +160,10 @@ case class Add(left: Expression, right: Expression) extends BinaryArithmetic {
}
}
-case class Subtract(left: Expression, right: Expression) extends BinaryArithmetic {
+@ExpressionDescription(
+ usage = "a _FUNC_ b - Returns a-b.")
+case class Subtract(left: Expression, right: Expression)
+ extends BinaryArithmetic with NullIntolerant {
override def inputType: AbstractDataType = TypeCollection.NumericAndInterval
@@ -181,7 +192,10 @@ case class Subtract(left: Expression, right: Expression) extends BinaryArithmeti
}
}
-case class Multiply(left: Expression, right: Expression) extends BinaryArithmetic {
+@ExpressionDescription(
+ usage = "a _FUNC_ b - Multiplies a by b.")
+case class Multiply(left: Expression, right: Expression)
+ extends BinaryArithmetic with NullIntolerant {
override def inputType: AbstractDataType = NumericType
@@ -193,7 +207,11 @@ case class Multiply(left: Expression, right: Expression) extends BinaryArithmeti
protected override def nullSafeEval(input1: Any, input2: Any): Any = numeric.times(input1, input2)
}
-case class Divide(left: Expression, right: Expression) extends BinaryArithmetic {
+@ExpressionDescription(
+ usage = "a _FUNC_ b - Divides a by b.",
+ extended = "> SELECT 3 _FUNC_ 2;\n 1.5")
+case class Divide(left: Expression, right: Expression)
+ extends BinaryArithmetic with NullIntolerant {
override def inputType: AbstractDataType = NumericType
@@ -237,25 +255,42 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic
} else {
s"($javaType)(${eval1.value} $symbol ${eval2.value})"
}
- s"""
- ${eval2.code}
- boolean ${ev.isNull} = false;
- $javaType ${ev.value} = ${ctx.defaultValue(javaType)};
- if (${eval2.isNull} || $isZero) {
- ${ev.isNull} = true;
- } else {
- ${eval1.code}
- if (${eval1.isNull}) {
+ if (!left.nullable && !right.nullable) {
+ s"""
+ ${eval2.code}
+ boolean ${ev.isNull} = false;
+ $javaType ${ev.value} = ${ctx.defaultValue(javaType)};
+ if ($isZero) {
${ev.isNull} = true;
} else {
+ ${eval1.code}
${ev.value} = $divide;
}
- }
- """
+ """
+ } else {
+ s"""
+ ${eval2.code}
+ boolean ${ev.isNull} = false;
+ $javaType ${ev.value} = ${ctx.defaultValue(javaType)};
+ if (${eval2.isNull} || $isZero) {
+ ${ev.isNull} = true;
+ } else {
+ ${eval1.code}
+ if (${eval1.isNull}) {
+ ${ev.isNull} = true;
+ } else {
+ ${ev.value} = $divide;
+ }
+ }
+ """
+ }
}
}
-case class Remainder(left: Expression, right: Expression) extends BinaryArithmetic {
+@ExpressionDescription(
+ usage = "a _FUNC_ b - Returns the remainder when dividing a by b.")
+case class Remainder(left: Expression, right: Expression)
+ extends BinaryArithmetic with NullIntolerant {
override def inputType: AbstractDataType = NumericType
@@ -299,21 +334,35 @@ case class Remainder(left: Expression, right: Expression) extends BinaryArithmet
} else {
s"($javaType)(${eval1.value} $symbol ${eval2.value})"
}
- s"""
- ${eval2.code}
- boolean ${ev.isNull} = false;
- $javaType ${ev.value} = ${ctx.defaultValue(javaType)};
- if (${eval2.isNull} || $isZero) {
- ${ev.isNull} = true;
- } else {
- ${eval1.code}
- if (${eval1.isNull}) {
+ if (!left.nullable && !right.nullable) {
+ s"""
+ ${eval2.code}
+ boolean ${ev.isNull} = false;
+ $javaType ${ev.value} = ${ctx.defaultValue(javaType)};
+ if ($isZero) {
${ev.isNull} = true;
} else {
+ ${eval1.code}
${ev.value} = $remainder;
}
- }
- """
+ """
+ } else {
+ s"""
+ ${eval2.code}
+ boolean ${ev.isNull} = false;
+ $javaType ${ev.value} = ${ctx.defaultValue(javaType)};
+ if (${eval2.isNull} || $isZero) {
+ ${ev.isNull} = true;
+ } else {
+ ${eval1.code}
+ if (${eval1.isNull}) {
+ ${ev.isNull} = true;
+ } else {
+ ${ev.value} = $remainder;
+ }
+ }
+ """
+ }
}
}
@@ -429,7 +478,10 @@ case class MinOf(left: Expression, right: Expression)
override def symbol: String = "min"
}
-case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic {
+@ExpressionDescription(
+ usage = "_FUNC_(a, b) - Returns the positive modulo",
+ extended = "> SELECT _FUNC_(10,3);\n 1")
+case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic with NullIntolerant {
override def toString: String = s"pmod($left, $right)"
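The Divide and Remainder hunks above specialize the generated Java on operand nullability: when neither operand is nullable, only the division-by-zero guard survives, otherwise nulls from either side are propagated first. A hypothetical, heavily simplified rendering of that branching (plain string templates instead of the real ExprCode/CodegenContext machinery):

object DivideCodegenSketch {
  // a/b stand in for the evaluated operand values, aIsNull/bIsNull for their null flags.
  def divideCode(leftNullable: Boolean, rightNullable: Boolean): String =
    if (!leftNullable && !rightNullable) {
      // Fast path: no per-operand null checks are emitted.
      """boolean isNull = false;
        |double value = 0.0;
        |if (b == 0) { isNull = true; } else { value = a / b; }""".stripMargin
    } else {
      // General path: propagate nulls before dividing.
      """boolean isNull = false;
        |double value = 0.0;
        |if (bIsNull || b == 0) { isNull = true; }
        |else if (aIsNull) { isNull = true; }
        |else { value = a / b; }""".stripMargin
    }

  def main(args: Array[String]): Unit = println(divideCode(false, false))
}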
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala
index 4c90b3f7d3..a7e1cd66f2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala
@@ -26,6 +26,9 @@ import org.apache.spark.sql.types._
*
* Code generation inherited from BinaryArithmetic.
*/
+@ExpressionDescription(
+ usage = "a _FUNC_ b - Bitwise AND.",
+ extended = "> SELECT 3 _FUNC_ 5; 1")
case class BitwiseAnd(left: Expression, right: Expression) extends BinaryArithmetic {
override def inputType: AbstractDataType = IntegralType
@@ -51,6 +54,9 @@ case class BitwiseAnd(left: Expression, right: Expression) extends BinaryArithme
*
* Code generation inherited from BinaryArithmetic.
*/
+@ExpressionDescription(
+ usage = "a _FUNC_ b - Bitwise OR.",
+ extended = "> SELECT 3 _FUNC_ 5; 7")
case class BitwiseOr(left: Expression, right: Expression) extends BinaryArithmetic {
override def inputType: AbstractDataType = IntegralType
@@ -76,6 +82,9 @@ case class BitwiseOr(left: Expression, right: Expression) extends BinaryArithmet
*
* Code generation inherited from BinaryArithmetic.
*/
+@ExpressionDescription(
+ usage = "a _FUNC_ b - Bitwise exclusive OR.",
+ extended = "> SELECT 3 _FUNC_ 5; 2")
case class BitwiseXor(left: Expression, right: Expression) extends BinaryArithmetic {
override def inputType: AbstractDataType = IntegralType
@@ -99,6 +108,9 @@ case class BitwiseXor(left: Expression, right: Expression) extends BinaryArithme
/**
* A function that calculates bitwise not(~) of a number.
*/
+@ExpressionDescription(
+ usage = "_FUNC_ b - Bitwise NOT.",
+ extended = "> SELECT _FUNC_ 0; -1")
case class BitwiseNot(child: Expression) extends UnaryExpression with ExpectsInputTypes {
override def inputTypes: Seq[AbstractDataType] = Seq(IntegralType)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatter.scala
index 9d99bbffbe..ab4831f7ab 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatter.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatter.scala
@@ -43,15 +43,45 @@ object CodeFormatter {
private class CodeFormatter {
private val code = new StringBuilder
- private var indentLevel = 0
private val indentSize = 2
+
+ // Tracks the level of indentation in the current line.
+ private var indentLevel = 0
private var indentString = ""
private var currentLine = 1
+ // Tracks the level of indentation in multi-line comment blocks.
+ private var inCommentBlock = false
+ private var indentLevelOutsideCommentBlock = indentLevel
+
private def addLine(line: String): Unit = {
- val indentChange =
- line.count(c => "({".indexOf(c) >= 0) - line.count(c => ")}".indexOf(c) >= 0)
- val newIndentLevel = math.max(0, indentLevel + indentChange)
+
+ // We currently infer the level of indentation of a given line based on a simple heuristic that
+ // examines the number of parentheses and braces in that line. This isn't the most robust
+ // implementation but works for all code that we generate.
+ val indentChange = line.count(c => "({".indexOf(c) >= 0) - line.count(c => ")}".indexOf(c) >= 0)
+ var newIndentLevel = math.max(0, indentLevel + indentChange)
+
+ // Please note that while we try to format the comment blocks in exactly the same way as the
+ // rest of the code, once the block ends, we reset the next line's indentation level to what it
+ // was immediately before entering the comment block.
+ if (!inCommentBlock) {
+ if (line.startsWith("/*")) {
+ // Handle multi-line comments
+ inCommentBlock = true
+ indentLevelOutsideCommentBlock = indentLevel
+ } else if (line.startsWith("//")) {
+ // Handle single line comments
+ newIndentLevel = indentLevel
+ }
+ }
+ if (inCommentBlock) {
+ if (line.endsWith("*/")) {
+ inCommentBlock = false
+ newIndentLevel = indentLevelOutsideCommentBlock
+ }
+ }
+
// Lines starting with '}' should be de-indented even if they contain '{' after;
// in addition, lines ending with ':' are typically labels
val thisLineIndent = if (line.startsWith("}") || line.startsWith(")") || line.endsWith(":")) {
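As the new comments in this hunk explain, CodeFormatter infers indentation from a per-line count of opening versus closing brackets, and the extra state only prevents comment blocks from permanently shifting that level. A self-contained sketch of the heuristic (not the actual CodeFormatter, which also handles labels and a configurable indent width):

object IndentSketch {
  def indentLevels(lines: Seq[String]): Seq[Int] = {
    var level = 0
    var inComment = false
    var levelOutsideComment = 0
    lines.map { raw =>
      val line = raw.trim
      val change = line.count(c => "({".indexOf(c) >= 0) - line.count(c => ")}".indexOf(c) >= 0)
      var next = math.max(0, level + change)
      if (!inComment && line.startsWith("/*")) { inComment = true; levelOutsideComment = level }
      else if (!inComment && line.startsWith("//")) next = level  // single-line comments
      if (inComment && line.endsWith("*/")) { inComment = false; next = levelOutsideComment }
      // Lines that start by closing a block are printed at the reduced level.
      val thisLine = if (line.startsWith("}") || line.startsWith(")")) next else level
      level = next
      thisLine
    }
  }

  def main(args: Array[String]): Unit =
    assert(indentLevels(Seq("class A {", "def f() {", "}", "}")) == Seq(0, 1, 1, 0))
}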
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index b511b4b3a0..f43626ca81 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -58,10 +58,10 @@ class CodegenContext {
val references: mutable.ArrayBuffer[Any] = new mutable.ArrayBuffer[Any]()
/**
- * Add an object to `references`, create a class member to access it.
- *
- * Returns the name of class member.
- */
+ * Add an object to `references`, create a class member to access it.
+ *
+ * Returns the name of class member.
+ */
def addReferenceObj(name: String, obj: Any, className: String = null): String = {
val term = freshName(name)
val idx = references.length
@@ -72,9 +72,9 @@ class CodegenContext {
}
/**
- * Holding a list of generated columns as input of current operator, will be used by
- * BoundReference to generate code.
- */
+ * Holding a list of generated columns as input of current operator, will be used by
+ * BoundReference to generate code.
+ */
var currentVars: Seq[ExprCode] = null
/**
@@ -169,14 +169,14 @@ class CodegenContext {
final var INPUT_ROW = "i"
/**
- * The map from a variable name to it's next ID.
- */
+ * The map from a variable name to its next ID.
+ */
private val freshNameIds = new mutable.HashMap[String, Int]
freshNameIds += INPUT_ROW -> 1
/**
- * A prefix used to generate fresh name.
- */
+ * A prefix used to generate fresh name.
+ */
var freshNamePrefix = ""
/**
@@ -234,8 +234,8 @@ class CodegenContext {
}
/**
- * Update a column in MutableRow from ExprCode.
- */
+ * Update a column in MutableRow from ExprCode.
+ */
def updateColumn(
row: String,
dataType: DataType,
@@ -509,7 +509,7 @@ class CodegenContext {
/**
* Checks and sets up the state and codegen for subexpression elimination. This finds the
- * common subexpresses, generates the functions that evaluate those expressions and populates
+ * common subexpressions, generates the functions that evaluate those expressions and populates
* the mapping of common subexpressions to the generated functions.
*/
private def subexpressionElimination(expressions: Seq[Expression]) = {
@@ -519,7 +519,7 @@ class CodegenContext {
// Get all the expressions that appear at least twice and set up the state for subexpression
// elimination.
val commonExprs = equivalentExpressions.getAllEquivalentExprs.filter(_.size > 1)
- commonExprs.foreach(e => {
+ commonExprs.foreach { e =>
val expr = e.head
val fnName = freshName("evalExpr")
val isNull = s"${fnName}IsNull"
@@ -561,7 +561,7 @@ class CodegenContext {
subexprFunctions += s"$fnName($INPUT_ROW);"
val state = SubExprEliminationState(isNull, value)
e.foreach(subExprEliminationExprs.put(_, state))
- })
+ }
}
/**
@@ -626,15 +626,15 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
object CodeGenerator extends Logging {
/**
- * Compile the Java source code into a Java class, using Janino.
- */
+ * Compile the Java source code into a Java class, using Janino.
+ */
def compile(code: String): GeneratedClass = {
cache.get(code)
}
/**
- * Compile the Java source code into a Java class, using Janino.
- */
+ * Compile the Java source code into a Java class, using Janino.
+ */
private[this] def doCompile(code: String): GeneratedClass = {
val evaluator = new ClassBodyEvaluator()
evaluator.setParentClassLoader(Utils.getContextOrSparkClassLoader)
@@ -661,7 +661,7 @@ object CodeGenerator extends Logging {
logDebug({
// Only add extra debugging info to byte code when we are going to print the source code.
evaluator.setDebuggingInformation(true, true, false)
- formatted
+ s"\n$formatted"
})
try {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index e36c985249..ab790cf372 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -26,6 +26,8 @@ import org.apache.spark.sql.types._
/**
* Given an array or map, returns its size.
*/
+@ExpressionDescription(
+ usage = "_FUNC_(expr) - Returns the size of an array or a map.")
case class Size(child: Expression) extends UnaryExpression with ExpectsInputTypes {
override def dataType: DataType = IntegerType
override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(ArrayType, MapType))
@@ -44,6 +46,11 @@ case class Size(child: Expression) extends UnaryExpression with ExpectsInputType
* Sorts the input array in ascending / descending order according to the natural ordering of
* the array elements and returns it.
*/
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+ usage = "_FUNC_(array(obj1, obj2,...)) - Sorts the input array in ascending order according to the natural ordering of the array elements.",
+ extended = " > SELECT _FUNC_(array('b', 'd', 'c', 'a'));\n 'a', 'b', 'c', 'd'")
+// scalastyle:on line.size.limit
case class SortArray(base: Expression, ascendingOrder: Expression)
extends BinaryExpression with ExpectsInputTypes with CodegenFallback {
@@ -125,6 +132,9 @@ case class SortArray(base: Expression, ascendingOrder: Expression)
/**
* Checks if the array (left) has the element (right)
*/
+@ExpressionDescription(
+ usage = "_FUNC_(array, value) - Returns TRUE if the array contains value.",
+ extended = " > SELECT _FUNC_(array(1, 2, 3), 2);\n true")
case class ArrayContains(left: Expression, right: Expression)
extends BinaryExpression with ImplicitCastInputTypes {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
index c299586dde..74de4a776d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
@@ -27,6 +27,8 @@ import org.apache.spark.unsafe.types.UTF8String
/**
* Returns an Array containing the evaluation of all children expressions.
*/
+@ExpressionDescription(
+ usage = "_FUNC_(n0, ...) - Returns an array with the given elements.")
case class CreateArray(children: Seq[Expression]) extends Expression {
override def foldable: Boolean = children.forall(_.foldable)
@@ -73,6 +75,8 @@ case class CreateArray(children: Seq[Expression]) extends Expression {
* Returns a catalyst Map containing the evaluation of all children expressions as keys and values.
* The children are a flatted sequence of kv pairs, e.g. (key1, value1, key2, value2, ...)
*/
+@ExpressionDescription(
+ usage = "_FUNC_(key0, value0, key1, value1...) - Creates a map with the given key/value pairs.")
case class CreateMap(children: Seq[Expression]) extends Expression {
private[sql] lazy val keys = children.indices.filter(_ % 2 == 0).map(children)
private[sql] lazy val values = children.indices.filter(_ % 2 != 0).map(children)
@@ -153,6 +157,8 @@ case class CreateMap(children: Seq[Expression]) extends Expression {
/**
* Returns a Row containing the evaluation of all children expressions.
*/
+@ExpressionDescription(
+ usage = "_FUNC_(col1, col2, col3, ...) - Creates a struct with the given field values.")
case class CreateStruct(children: Seq[Expression]) extends Expression {
override def foldable: Boolean = children.forall(_.foldable)
@@ -204,6 +210,10 @@ case class CreateStruct(children: Seq[Expression]) extends Expression {
*
* @param children Seq(name1, val1, name2, val2, ...)
*/
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+ usage = "_FUNC_(name1, val1, name2, val2, ...) - Creates a struct with the given field names and values.")
+// scalastyle:on line.size.limit
case class CreateNamedStruct(children: Seq[Expression]) extends Expression {
/**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala
index 103ab365e3..ae6a94842f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala
@@ -23,7 +23,10 @@ import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.types._
-
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+ usage = "_FUNC_(expr1,expr2,expr3) - If expr1 is TRUE then IF() returns expr2; otherwise it returns expr3.")
+// scalastyle:on line.size.limit
case class If(predicate: Expression, trueValue: Expression, falseValue: Expression)
extends Expression {
@@ -85,6 +88,10 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi
* @param branches seq of (branch condition, branch value)
* @param elseValue optional value for the else branch
*/
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+ usage = "CASE WHEN a THEN b [WHEN c THEN d]* [ELSE e] END - When a = true, returns b; when c = true, return d; else return e.")
+// scalastyle:on line.size.limit
case class CaseWhen(branches: Seq[(Expression, Expression)], elseValue: Option[Expression] = None)
extends Expression with CodegenFallback {
@@ -222,7 +229,7 @@ object CaseWhen {
}
/**
- * A factory method to faciliate the creation of this expression when used in parsers.
+ * A factory method to facilitate the creation of this expression when used in parsers.
* @param branches Expressions at even position are the branch conditions, and expressions at odd
* position are branch values.
*/
@@ -256,6 +263,8 @@ object CaseKeyWhen {
* A function that returns the least value of all parameters, skipping null values.
* It takes at least 2 parameters, and returns null iff all parameters are null.
*/
+@ExpressionDescription(
+ usage = "_FUNC_(n1, ...) - Returns the least value of all parameters, skipping null values.")
case class Least(children: Seq[Expression]) extends Expression {
override def nullable: Boolean = children.forall(_.nullable)
@@ -315,6 +324,8 @@ case class Least(children: Seq[Expression]) extends Expression {
* A function that returns the greatest value of all parameters, skipping null values.
* It takes at least 2 parameters, and returns null iff all parameters are null.
*/
+@ExpressionDescription(
+ usage = "_FUNC_(n1, ...) - Returns the greatest value of all parameters, skipping null values.")
case class Greatest(children: Seq[Expression]) extends Expression {
override def nullable: Boolean = children.forall(_.nullable)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala
index 1d0ea68d7a..9135753041 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala
@@ -35,6 +35,8 @@ import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
*
* There is no code generation since this expression should get constant folded by the optimizer.
*/
+@ExpressionDescription(
+ usage = "_FUNC_() - Returns the current date at the start of query evaluation.")
case class CurrentDate() extends LeafExpression with CodegenFallback {
override def foldable: Boolean = true
override def nullable: Boolean = false
@@ -54,6 +56,8 @@ case class CurrentDate() extends LeafExpression with CodegenFallback {
*
* There is no code generation since this expression should get constant folded by the optimizer.
*/
+@ExpressionDescription(
+ usage = "_FUNC_() - Returns the current timestamp at the start of query evaluation.")
case class CurrentTimestamp() extends LeafExpression with CodegenFallback {
override def foldable: Boolean = true
override def nullable: Boolean = false
@@ -70,6 +74,9 @@ case class CurrentTimestamp() extends LeafExpression with CodegenFallback {
/**
* Adds a number of days to startdate.
*/
+@ExpressionDescription(
+ usage = "_FUNC_(start_date, num_days) - Returns the date that is num_days after start_date.",
+ extended = "> SELECT _FUNC_('2016-07-30', 1);\n '2016-07-31'")
case class DateAdd(startDate: Expression, days: Expression)
extends BinaryExpression with ImplicitCastInputTypes {
@@ -96,6 +103,9 @@ case class DateAdd(startDate: Expression, days: Expression)
/**
* Subtracts a number of days to startdate.
*/
+@ExpressionDescription(
+ usage = "_FUNC_(start_date, num_days) - Returns the date that is num_days before start_date.",
+ extended = "> SELECT _FUNC_('2016-07-30', 1);\n '2016-07-29'")
case class DateSub(startDate: Expression, days: Expression)
extends BinaryExpression with ImplicitCastInputTypes {
override def left: Expression = startDate
@@ -118,6 +128,9 @@ case class DateSub(startDate: Expression, days: Expression)
override def prettyName: String = "date_sub"
}
+@ExpressionDescription(
+ usage = "_FUNC_(param) - Returns the hour component of the string/timestamp/interval.",
+ extended = "> SELECT _FUNC_('2009-07-30 12:58:59');\n 12")
case class Hour(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {
override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType)
@@ -134,6 +147,9 @@ case class Hour(child: Expression) extends UnaryExpression with ImplicitCastInpu
}
}
+@ExpressionDescription(
+ usage = "_FUNC_(param) - Returns the minute component of the string/timestamp/interval.",
+ extended = "> SELECT _FUNC_('2009-07-30 12:58:59');\n 58")
case class Minute(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {
override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType)
@@ -150,6 +166,9 @@ case class Minute(child: Expression) extends UnaryExpression with ImplicitCastIn
}
}
+@ExpressionDescription(
+ usage = "_FUNC_(param) - Returns the second component of the string/timestamp/interval.",
+ extended = "> SELECT _FUNC_('2009-07-30 12:58:59');\n 59")
case class Second(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {
override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType)
@@ -166,6 +185,9 @@ case class Second(child: Expression) extends UnaryExpression with ImplicitCastIn
}
}
+@ExpressionDescription(
+ usage = "_FUNC_(param) - Returns the day of year of date/timestamp.",
+ extended = "> SELECT _FUNC_('2016-04-09');\n 100")
case class DayOfYear(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {
override def inputTypes: Seq[AbstractDataType] = Seq(DateType)
@@ -182,7 +204,9 @@ case class DayOfYear(child: Expression) extends UnaryExpression with ImplicitCas
}
}
-
+@ExpressionDescription(
+ usage = "_FUNC_(param) - Returns the year component of the date/timestamp/interval.",
+ extended = "> SELECT _FUNC_('2016-07-30');\n 2016")
case class Year(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {
override def inputTypes: Seq[AbstractDataType] = Seq(DateType)
@@ -199,6 +223,8 @@ case class Year(child: Expression) extends UnaryExpression with ImplicitCastInpu
}
}
+@ExpressionDescription(
+ usage = "_FUNC_(param) - Returns the quarter of the year for date, in the range 1 to 4.")
case class Quarter(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {
override def inputTypes: Seq[AbstractDataType] = Seq(DateType)
@@ -215,6 +241,9 @@ case class Quarter(child: Expression) extends UnaryExpression with ImplicitCastI
}
}
+@ExpressionDescription(
+ usage = "_FUNC_(param) - Returns the month component of the date/timestamp/interval",
+ extended = "> SELECT _FUNC_('2016-07-30');\n 7")
case class Month(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {
override def inputTypes: Seq[AbstractDataType] = Seq(DateType)
@@ -231,6 +260,9 @@ case class Month(child: Expression) extends UnaryExpression with ImplicitCastInp
}
}
+@ExpressionDescription(
+ usage = "_FUNC_(param) - Returns the day of month of date/timestamp, or the day of interval.",
+ extended = "> SELECT _FUNC_('2009-07-30');\n 30")
case class DayOfMonth(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {
override def inputTypes: Seq[AbstractDataType] = Seq(DateType)
@@ -247,6 +279,9 @@ case class DayOfMonth(child: Expression) extends UnaryExpression with ImplicitCa
}
}
+@ExpressionDescription(
+ usage = "_FUNC_(param) - Returns the week of the year of the given date.",
+ extended = "> SELECT _FUNC_('2008-02-20');\n 8")
case class WeekOfYear(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {
override def inputTypes: Seq[AbstractDataType] = Seq(DateType)
@@ -283,6 +318,11 @@ case class WeekOfYear(child: Expression) extends UnaryExpression with ImplicitCa
}
}
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+ usage = "_FUNC_(date/timestamp/string, fmt) - Converts a date/timestamp/string to a value of string in the format specified by the date format fmt.",
+ extended = "> SELECT _FUNC_('2016-04-08', 'y')\n '2016'")
+// scalastyle:on line.size.limit
case class DateFormatClass(left: Expression, right: Expression) extends BinaryExpression
with ImplicitCastInputTypes {
@@ -310,6 +350,8 @@ case class DateFormatClass(left: Expression, right: Expression) extends BinaryEx
* Converts time string with given pattern.
* Deterministic version of [[UnixTimestamp]], must have at least one parameter.
*/
+@ExpressionDescription(
+ usage = "_FUNC_(date[, pattern]) - Returns the UNIX timestamp of the give time.")
case class ToUnixTimestamp(timeExp: Expression, format: Expression) extends UnixTime {
override def left: Expression = timeExp
override def right: Expression = format
@@ -331,6 +373,8 @@ case class ToUnixTimestamp(timeExp: Expression, format: Expression) extends Unix
* If the first parameter is a Date or Timestamp instead of String, we will ignore the
* second parameter.
*/
+@ExpressionDescription(
+ usage = "_FUNC_([date[, pattern]]) - Returns the UNIX timestamp of current or specified time.")
case class UnixTimestamp(timeExp: Expression, format: Expression) extends UnixTime {
override def left: Expression = timeExp
override def right: Expression = format
@@ -459,6 +503,9 @@ abstract class UnixTime extends BinaryExpression with ExpectsInputTypes {
* format. If the format is missing, using format like "1970-01-01 00:00:00".
* Note that hive Language Manual says it returns 0 if fail, but in fact it returns null.
*/
+@ExpressionDescription(
+ usage = "_FUNC_(unix_time, format) - Returns unix_time in the specified format",
+ extended = "> SELECT _FUNC_(0, 'yyyy-MM-dd HH:mm:ss');\n '1970-01-01 00:00:00'")
case class FromUnixTime(sec: Expression, format: Expression)
extends BinaryExpression with ImplicitCastInputTypes {
@@ -544,6 +591,9 @@ case class FromUnixTime(sec: Expression, format: Expression)
/**
* Returns the last day of the month which the date belongs to.
*/
+@ExpressionDescription(
+ usage = "_FUNC_(date) - Returns the last day of the month which the date belongs to.",
+ extended = "> SELECT _FUNC_('2009-01-12');\n '2009-01-31'")
case class LastDay(startDate: Expression) extends UnaryExpression with ImplicitCastInputTypes {
override def child: Expression = startDate
@@ -570,6 +620,11 @@ case class LastDay(startDate: Expression) extends UnaryExpression with ImplicitC
*
* Allowed "dayOfWeek" is defined in [[DateTimeUtils.getDayOfWeekFromString]].
*/
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+ usage = "_FUNC_(start_date, day_of_week) - Returns the first date which is later than start_date and named as indicated.",
+ extended = "> SELECT _FUNC_('2015-01-14', 'TU');\n '2015-01-20'")
+// scalastyle:on line.size.limit
case class NextDay(startDate: Expression, dayOfWeek: Expression)
extends BinaryExpression with ImplicitCastInputTypes {
@@ -654,6 +709,10 @@ case class TimeAdd(start: Expression, interval: Expression)
/**
* Assumes given timestamp is UTC and converts to given timezone.
*/
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+ usage = "_FUNC_(timestamp, string timezone) - Assumes given timestamp is UTC and converts to given timezone.")
+// scalastyle:on line.size.limit
case class FromUTCTimestamp(left: Expression, right: Expression)
extends BinaryExpression with ImplicitCastInputTypes {
@@ -729,6 +788,9 @@ case class TimeSub(start: Expression, interval: Expression)
/**
* Returns the date that is num_months after start_date.
*/
+@ExpressionDescription(
+ usage = "_FUNC_(start_date, num_months) - Returns the date that is num_months after start_date.",
+ extended = "> SELECT _FUNC_('2016-08-31', 1);\n '2016-09-30'")
case class AddMonths(startDate: Expression, numMonths: Expression)
extends BinaryExpression with ImplicitCastInputTypes {
@@ -756,6 +818,9 @@ case class AddMonths(startDate: Expression, numMonths: Expression)
/**
* Returns number of months between dates date1 and date2.
*/
+@ExpressionDescription(
+ usage = "_FUNC_(date1, date2) - returns number of months between dates date1 and date2.",
+ extended = "> SELECT _FUNC_('1997-02-28 10:30:00', '1996-10-30');\n 3.94959677")
case class MonthsBetween(date1: Expression, date2: Expression)
extends BinaryExpression with ImplicitCastInputTypes {
@@ -783,6 +848,10 @@ case class MonthsBetween(date1: Expression, date2: Expression)
/**
* Assumes given timestamp is in given timezone and converts to UTC.
*/
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+ usage = "_FUNC_(timestamp, string timezone) - Assumes given timestamp is in given timezone and converts to UTC.")
+// scalastyle:on line.size.limit
case class ToUTCTimestamp(left: Expression, right: Expression)
extends BinaryExpression with ImplicitCastInputTypes {
@@ -830,6 +899,9 @@ case class ToUTCTimestamp(left: Expression, right: Expression)
/**
* Returns the date part of a timestamp or string.
*/
+@ExpressionDescription(
+ usage = "_FUNC_(expr) - Extracts the date part of the date or datetime expression expr.",
+ extended = "> SELECT _FUNC_('2009-07-30 04:17:52');\n '2009-07-30'")
case class ToDate(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {
// Implicit casting of spark will accept string in both date and timestamp format, as
@@ -850,6 +922,11 @@ case class ToDate(child: Expression) extends UnaryExpression with ImplicitCastIn
/**
* Returns date truncated to the unit specified by the format.
*/
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+ usage = "_FUNC_(date, fmt) - Returns returns date with the time portion of the day truncated to the unit specified by the format model fmt.",
+ extended = "> SELECT _FUNC_('2009-02-12', 'MM')\n '2009-02-01'\n> SELECT _FUNC_('2015-10-27', 'YEAR');\n '2015-01-01'")
+// scalastyle:on line.size.limit
case class TruncDate(date: Expression, format: Expression)
extends BinaryExpression with ImplicitCastInputTypes {
override def left: Expression = date
@@ -921,6 +998,9 @@ case class TruncDate(date: Expression, format: Expression)
/**
* Returns the number of days from startDate to endDate.
*/
+@ExpressionDescription(
+ usage = "_FUNC_(date1, date2) - Returns the number of days between date1 and date2.",
+ extended = "> SELECT _FUNC_('2009-07-30', '2009-07-31');\n 1")
case class DateDiff(endDate: Expression, startDate: Expression)
extends BinaryExpression with ImplicitCastInputTypes {
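The usage and extended strings added above are what DESCRIBE FUNCTION EXTENDED surfaces for the corresponding SQL functions. A minimal sketch of checking them, assuming a Spark 2.x SparkSession named `spark` (not part of this patch):

    // Print the registered usage/extended text, then run one of the documented examples.
    spark.sql("DESCRIBE FUNCTION EXTENDED from_unixtime").collect().foreach(println)
    spark.sql("SELECT from_unixtime(0, 'yyyy-MM-dd HH:mm:ss'), last_day('2009-01-12')").show()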
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
index e7ef21aa85..65d7a1d5a0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
@@ -99,6 +99,10 @@ case class UserDefinedGenerator(
/**
* Given an input array produces a sequence of rows for each value in the array.
*/
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+ usage = "_FUNC_(a) - Separates the elements of array a into multiple rows, or the elements of a map into multiple rows and columns.")
+// scalastyle:on line.size.limit
case class Explode(child: Expression) extends UnaryExpression with Generator with CodegenFallback {
override def children: Seq[Expression] = child :: Nil
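A hedged usage sketch for explode (SparkSession `spark` assumed): arrays become one row per element, maps become (key, value) rows.

    // One row per array element.
    spark.sql("SELECT explode(array(1, 2, 3)) AS value").show()
    // One (key, value) row per map entry.
    spark.sql("SELECT explode(map('a', 1, 'b', 2))").show()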
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/grouping.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/grouping.scala
index 437e417266..3be761c867 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/grouping.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/grouping.scala
@@ -22,8 +22,8 @@ import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
import org.apache.spark.sql.types._
/**
- * A placeholder expression for cube/rollup, which will be replaced by analyzer
- */
+ * A placeholder expression for cube/rollup, which will be replaced by analyzer
+ */
trait GroupingSet extends Expression with CodegenFallback {
def groupByExprs: Seq[Expression]
@@ -43,9 +43,9 @@ case class Cube(groupByExprs: Seq[Expression]) extends GroupingSet {}
case class Rollup(groupByExprs: Seq[Expression]) extends GroupingSet {}
/**
- * Indicates whether a specified column expression in a GROUP BY list is aggregated or not.
- * GROUPING returns 1 for aggregated or 0 for not aggregated in the result set.
- */
+ * Indicates whether a specified column expression in a GROUP BY list is aggregated or not.
+ * GROUPING returns 1 for aggregated or 0 for not aggregated in the result set.
+ */
case class Grouping(child: Expression) extends Expression with Unevaluable {
override def references: AttributeSet = AttributeSet(VirtualColumn.groupingIdAttribute :: Nil)
override def children: Seq[Expression] = child :: Nil
@@ -54,10 +54,10 @@ case class Grouping(child: Expression) extends Expression with Unevaluable {
}
/**
- * GroupingID is a function that computes the level of grouping.
- *
- * If groupByExprs is empty, it means all grouping expressions in GroupingSets.
- */
+ * GroupingID is a function that computes the level of grouping.
+ *
+ * If groupByExprs is empty, it means all grouping expressions in GroupingSets.
+ */
case class GroupingID(groupByExprs: Seq[Expression]) extends Expression with Unevaluable {
override def references: AttributeSet = AttributeSet(VirtualColumn.groupingIdAttribute :: Nil)
override def children: Seq[Expression] = groupByExprs
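The GROUPING and GROUPING_ID semantics described above can be exercised with ROLLUP; a hedged sketch, where `employees` is a hypothetical table and `spark` an assumed SparkSession:

    // grouping(col) is 1 when col is aggregated away in a ROLLUP/CUBE row, 0 otherwise;
    // grouping_id() packs those bits into a single grouping-level indicator.
    spark.sql("""
      SELECT dept, name, grouping(dept), grouping_id(), sum(salary)
      FROM employees
      GROUP BY dept, name WITH ROLLUP
    """).show()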
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala
index 72b323587c..ecd09b7083 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala
@@ -106,6 +106,8 @@ private[this] object SharedFactory {
* Extracts json object from a json string based on json path specified, and returns json string
* of the extracted json object. It will return null if the input json string is invalid.
*/
+@ExpressionDescription(
+ usage = "_FUNC_(json_txt, path) - Extract a json object from path")
case class GetJsonObject(json: Expression, path: Expression)
extends BinaryExpression with ExpectsInputTypes with CodegenFallback {
@@ -319,6 +321,10 @@ case class GetJsonObject(json: Expression, path: Expression)
}
}
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+ usage = "_FUNC_(jsonStr, p1, p2, ..., pn) - like get_json_object, but it takes multiple names and return a tuple. All the input parameters and output column types are string.")
+// scalastyle:on line.size.limit
case class JsonTuple(children: Seq[Expression])
extends Generator with CodegenFallback {
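A hedged sketch of the two JSON helpers documented above (SparkSession `spark` assumed):

    // get_json_object extracts one path; json_tuple extracts several fields in one pass.
    spark.sql("""SELECT get_json_object('{"a": 1, "b": "x"}', '$.b')""").show()
    spark.sql("""SELECT json_tuple('{"a": 1, "b": "x"}', 'a', 'b')""").show()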
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
index e3d1bc127d..c8a28e8477 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
@@ -50,6 +50,7 @@ abstract class LeafMathExpression(c: Double, name: String)
/**
* A unary expression specifically for math functions. Math Functions expect a specific type of
* input format, therefore these functions extend `ExpectsInputTypes`.
+ *
* @param f The math function.
* @param name The short name of the function
*/
@@ -103,6 +104,7 @@ abstract class UnaryLogExpression(f: Double => Double, name: String)
/**
* A binary expression specifically for math functions that take two `Double`s as input and returns
* a `Double`.
+ *
* @param f The math function.
* @param name The short name of the function
*/
@@ -136,12 +138,18 @@ abstract class BinaryMathExpression(f: (Double, Double) => Double, name: String)
* Euler's number. Note that there is no code generation because this is only
* evaluated by the optimizer during constant folding.
*/
+@ExpressionDescription(
+ usage = "_FUNC_() - Returns Euler's number, E.",
+ extended = "> SELECT _FUNC_();\n 2.718281828459045")
case class EulerNumber() extends LeafMathExpression(math.E, "E")
/**
* Pi. Note that there is no code generation because this is only
* evaluated by the optimizer during constant folding.
*/
+@ExpressionDescription(
+ usage = "_FUNC_() - Returns PI.",
+ extended = "> SELECT _FUNC_();\n 3.141592653589793")
case class Pi() extends LeafMathExpression(math.Pi, "PI")
////////////////////////////////////////////////////////////////////////////////////////////////////
@@ -150,14 +158,29 @@ case class Pi() extends LeafMathExpression(math.Pi, "PI")
////////////////////////////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////////////////////////////
+@ExpressionDescription(
+ usage = "_FUNC_(x) - Returns the arc cosine of x if -1<=x<=1 or NaN otherwise.",
+ extended = "> SELECT _FUNC_(1);\n 0.0\n> SELECT _FUNC_(2);\n NaN")
case class Acos(child: Expression) extends UnaryMathExpression(math.acos, "ACOS")
+@ExpressionDescription(
+ usage = "_FUNC_(x) - Returns the arc sin of x if -1<=x<=1 or NaN otherwise.",
+ extended = "> SELECT _FUNC_(0);\n 0.0\n> SELECT _FUNC_(2);\n NaN")
case class Asin(child: Expression) extends UnaryMathExpression(math.asin, "ASIN")
+@ExpressionDescription(
+ usage = "_FUNC_(x) - Returns the arc tangent.",
+ extended = "> SELECT _FUNC_(0);\n 0.0")
case class Atan(child: Expression) extends UnaryMathExpression(math.atan, "ATAN")
+@ExpressionDescription(
+ usage = "_FUNC_(x) - Returns the cube root of a double value.",
+ extended = "> SELECT _FUNC_(27.0);\n 3.0")
case class Cbrt(child: Expression) extends UnaryMathExpression(math.cbrt, "CBRT")
+@ExpressionDescription(
+ usage = "_FUNC_(x) - Returns the smallest integer not smaller than x.",
+ extended = "> SELECT _FUNC_(-0.1);\n 0\n> SELECT _FUNC_(5);\n 5")
case class Ceil(child: Expression) extends UnaryMathExpression(math.ceil, "CEIL") {
override def dataType: DataType = child.dataType match {
case dt @ DecimalType.Fixed(_, 0) => dt
@@ -184,16 +207,26 @@ case class Ceil(child: Expression) extends UnaryMathExpression(math.ceil, "CEIL"
}
}
+@ExpressionDescription(
+ usage = "_FUNC_(x) - Returns the cosine of x.",
+ extended = "> SELECT _FUNC_(0);\n 1.0")
case class Cos(child: Expression) extends UnaryMathExpression(math.cos, "COS")
+@ExpressionDescription(
+ usage = "_FUNC_(x) - Returns the hyperbolic cosine of x.",
+ extended = "> SELECT _FUNC_(0);\n 1.0")
case class Cosh(child: Expression) extends UnaryMathExpression(math.cosh, "COSH")
/**
* Convert a num from one base to another
+ *
* @param numExpr the number to be converted
* @param fromBaseExpr from which base
* @param toBaseExpr to which base
*/
+@ExpressionDescription(
+ usage = "_FUNC_(num, from_base, to_base) - Convert num from from_base to to_base.",
+ extended = "> SELECT _FUNC_('100', 2, 10);\n '4'\n> SELECT _FUNC_(-10, 16, -10);\n '16'")
case class Conv(numExpr: Expression, fromBaseExpr: Expression, toBaseExpr: Expression)
extends TernaryExpression with ImplicitCastInputTypes {
@@ -222,10 +255,19 @@ case class Conv(numExpr: Expression, fromBaseExpr: Expression, toBaseExpr: Expre
}
}
+@ExpressionDescription(
+ usage = "_FUNC_(x) - Returns e to the power of x.",
+ extended = "> SELECT _FUNC_(0);\n 1.0")
case class Exp(child: Expression) extends UnaryMathExpression(math.exp, "EXP")
+@ExpressionDescription(
+ usage = "_FUNC_(x) - Returns exp(x) - 1.",
+ extended = "> SELECT _FUNC_(0);\n 0.0")
case class Expm1(child: Expression) extends UnaryMathExpression(math.expm1, "EXPM1")
+@ExpressionDescription(
+ usage = "_FUNC_(x) - Returns the largest integer not greater than x.",
+ extended = "> SELECT _FUNC_(-0.1);\n -1\n> SELECT _FUNC_(5);\n 5")
case class Floor(child: Expression) extends UnaryMathExpression(math.floor, "FLOOR") {
override def dataType: DataType = child.dataType match {
case dt @ DecimalType.Fixed(_, 0) => dt
@@ -283,6 +325,9 @@ object Factorial {
)
}
+@ExpressionDescription(
+ usage = "_FUNC_(n) - Returns n factorial for n is [0..20]. Otherwise, NULL.",
+ extended = "> SELECT _FUNC_(5);\n 120")
case class Factorial(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {
override def inputTypes: Seq[DataType] = Seq(IntegerType)
@@ -315,8 +360,14 @@ case class Factorial(child: Expression) extends UnaryExpression with ImplicitCas
}
}
+@ExpressionDescription(
+ usage = "_FUNC_(x) - Returns the natural logarithm of x with base e.",
+ extended = "> SELECT _FUNC_(1);\n 0.0")
case class Log(child: Expression) extends UnaryLogExpression(math.log, "LOG")
+@ExpressionDescription(
+ usage = "_FUNC_(x) - Returns the logarithm of x with base 2.",
+ extended = "> SELECT _FUNC_(2);\n 1.0")
case class Log2(child: Expression)
extends UnaryLogExpression((x: Double) => math.log(x) / math.log(2), "LOG2") {
override def genCode(ctx: CodegenContext, ev: ExprCode): String = {
@@ -332,36 +383,72 @@ case class Log2(child: Expression)
}
}
+@ExpressionDescription(
+ usage = "_FUNC_(x) - Returns the logarithm of x with base 10.",
+ extended = "> SELECT _FUNC_(10);\n 1.0")
case class Log10(child: Expression) extends UnaryLogExpression(math.log10, "LOG10")
+@ExpressionDescription(
+ usage = "_FUNC_(x) - Returns log(1 + x).",
+ extended = "> SELECT _FUNC_(0);\n 0.0")
case class Log1p(child: Expression) extends UnaryLogExpression(math.log1p, "LOG1P") {
protected override val yAsymptote: Double = -1.0
}
+@ExpressionDescription(
+ usage = "_FUNC_(x, d) - Return the rounded x at d decimal places.",
+ extended = "> SELECT _FUNC_(12.3456, 1);\n 12.3")
case class Rint(child: Expression) extends UnaryMathExpression(math.rint, "ROUND") {
override def funcName: String = "rint"
}
+@ExpressionDescription(
+ usage = "_FUNC_(x) - Returns the sign of x.",
+ extended = "> SELECT _FUNC_(40);\n 1.0")
case class Signum(child: Expression) extends UnaryMathExpression(math.signum, "SIGNUM")
+@ExpressionDescription(
+ usage = "_FUNC_(x) - Returns the sine of x.",
+ extended = "> SELECT _FUNC_(0);\n 0.0")
case class Sin(child: Expression) extends UnaryMathExpression(math.sin, "SIN")
+@ExpressionDescription(
+ usage = "_FUNC_(x) - Returns the hyperbolic sine of x.",
+ extended = "> SELECT _FUNC_(0);\n 0.0")
case class Sinh(child: Expression) extends UnaryMathExpression(math.sinh, "SINH")
+@ExpressionDescription(
+ usage = "_FUNC_(x) - Returns the square root of x.",
+ extended = "> SELECT _FUNC_(4);\n 2.0")
case class Sqrt(child: Expression) extends UnaryMathExpression(math.sqrt, "SQRT")
+@ExpressionDescription(
+ usage = "_FUNC_(x) - Returns the tangent of x.",
+ extended = "> SELECT _FUNC_(0);\n 0.0")
case class Tan(child: Expression) extends UnaryMathExpression(math.tan, "TAN")
+@ExpressionDescription(
+ usage = "_FUNC_(x) - Returns the hyperbolic tangent of x.",
+ extended = "> SELECT _FUNC_(0);\n 0.0")
case class Tanh(child: Expression) extends UnaryMathExpression(math.tanh, "TANH")
+@ExpressionDescription(
+ usage = "_FUNC_(x) - Converts radians to degrees.",
+ extended = "> SELECT _FUNC_(3.141592653589793);\n 180.0")
case class ToDegrees(child: Expression) extends UnaryMathExpression(math.toDegrees, "DEGREES") {
override def funcName: String = "toDegrees"
}
+@ExpressionDescription(
+ usage = "_FUNC_(x) - Converts degrees to radians.",
+ extended = "> SELECT _FUNC_(180);\n 3.141592653589793")
case class ToRadians(child: Expression) extends UnaryMathExpression(math.toRadians, "RADIANS") {
override def funcName: String = "toRadians"
}
+@ExpressionDescription(
+ usage = "_FUNC_(x) - Returns x in binary.",
+ extended = "> SELECT _FUNC_(13);\n '1101'")
case class Bin(child: Expression)
extends UnaryExpression with Serializable with ImplicitCastInputTypes {
@@ -453,6 +540,9 @@ object Hex {
* Otherwise if the number is a STRING, it converts each character into its hex representation
* and returns the resulting STRING. Negative numbers would be treated as two's complement.
*/
+@ExpressionDescription(
+ usage = "_FUNC_(x) - Convert the argument to hexadecimal.",
+ extended = "> SELECT _FUNC_(17);\n '11'\n> SELECT _FUNC_('Spark SQL');\n '537061726B2053514C'")
case class Hex(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {
override def inputTypes: Seq[AbstractDataType] =
@@ -481,6 +571,9 @@ case class Hex(child: Expression) extends UnaryExpression with ImplicitCastInput
* Performs the inverse operation of HEX.
* Resulting characters are returned as a byte array.
*/
+@ExpressionDescription(
+ usage = "_FUNC_(x) - Converts hexadecimal argument to binary.",
+ extended = "> SELECT decode(_FUNC_('537061726B2053514C'),'UTF-8');\n 'Spark SQL'")
case class Unhex(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {
override def inputTypes: Seq[AbstractDataType] = Seq(StringType)
@@ -509,7 +602,9 @@ case class Unhex(child: Expression) extends UnaryExpression with ImplicitCastInp
////////////////////////////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////////////////////////////
-
+@ExpressionDescription(
+ usage = "_FUNC_(x,y) - Returns the arc tangent2.",
+ extended = "> SELECT _FUNC_(0, 0);\n 0.0")
case class Atan2(left: Expression, right: Expression)
extends BinaryMathExpression(math.atan2, "ATAN2") {
@@ -523,6 +618,9 @@ case class Atan2(left: Expression, right: Expression)
}
}
+@ExpressionDescription(
+ usage = "_FUNC_(x1, x2) - Raise x1 to the power of x2.",
+ extended = "> SELECT _FUNC_(2, 3);\n 8.0")
case class Pow(left: Expression, right: Expression)
extends BinaryMathExpression(math.pow, "POWER") {
override def genCode(ctx: CodegenContext, ev: ExprCode): String = {
@@ -532,10 +630,14 @@ case class Pow(left: Expression, right: Expression)
/**
- * Bitwise unsigned left shift.
+ * Bitwise left shift.
+ *
* @param left the base number to shift.
* @param right number of bits to left shift.
*/
+@ExpressionDescription(
+ usage = "_FUNC_(a, b) - Bitwise left shift.",
+ extended = "> SELECT _FUNC_(2, 1);\n 4")
case class ShiftLeft(left: Expression, right: Expression)
extends BinaryExpression with ImplicitCastInputTypes {
@@ -558,10 +660,14 @@ case class ShiftLeft(left: Expression, right: Expression)
/**
- * Bitwise unsigned left shift.
+ * Bitwise right shift.
+ *
* @param left the base number to shift.
- * @param right number of bits to left shift.
+ * @param right number of bits to right shift.
*/
+@ExpressionDescription(
+ usage = "_FUNC_(a, b) - Bitwise right shift.",
+ extended = "> SELECT _FUNC_(4, 1);\n 2")
case class ShiftRight(left: Expression, right: Expression)
extends BinaryExpression with ImplicitCastInputTypes {
@@ -585,9 +691,13 @@ case class ShiftRight(left: Expression, right: Expression)
/**
* Bitwise unsigned right shift, for integer and long data type.
+ *
* @param left the base number.
* @param right the number of bits to right shift.
*/
+@ExpressionDescription(
+ usage = "_FUNC_(a, b) - Bitwise unsigned right shift.",
+ extended = "> SELECT _FUNC_(4, 1);\n 2")
case class ShiftRightUnsigned(left: Expression, right: Expression)
extends BinaryExpression with ImplicitCastInputTypes {
@@ -608,16 +718,22 @@ case class ShiftRightUnsigned(left: Expression, right: Expression)
}
}
-
+@ExpressionDescription(
+ usage = "_FUNC_(a, b) - Returns sqrt(a**2 + b**2).",
+ extended = "> SELECT _FUNC_(3, 4);\n 5.0")
case class Hypot(left: Expression, right: Expression)
extends BinaryMathExpression(math.hypot, "HYPOT")
/**
* Computes the logarithm of a number.
+ *
* @param left the logarithm base, default to e.
* @param right the number to compute the logarithm of.
*/
+@ExpressionDescription(
+ usage = "_FUNC_(b, x) - Returns the logarithm of x with base b.",
+ extended = "> SELECT _FUNC_(10, 100);\n 2.0")
case class Logarithm(left: Expression, right: Expression)
extends BinaryMathExpression((c1, c2) => math.log(c2) / math.log(c1), "LOG") {
@@ -674,6 +790,9 @@ case class Logarithm(left: Expression, right: Expression)
* @param child expr to be round, all [[NumericType]] is allowed as Input
* @param scale new scale to be round to, this should be a constant int at runtime
*/
+@ExpressionDescription(
+ usage = "_FUNC_(x, d) - Round x to d decimal places.",
+ extended = "> SELECT _FUNC_(12.3456, 1);\n 12.3")
case class Round(child: Expression, scale: Expression)
extends BinaryExpression with ImplicitCastInputTypes {
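A few of the documented math examples as plain SQL; a hedged sketch assuming a SparkSession `spark`:

    // conv('100', 2, 10) = '4', hex(17) = '11', shiftleft(2, 1) = 4, round(12.3456, 1) = 12.3
    spark.sql("SELECT conv('100', 2, 10), hex(17), shiftleft(2, 1), round(12.3456, 1)").show()
    // hypot(3, 4) = 5.0 and log(10, 100) = 2.0, matching the extended strings above.
    spark.sql("SELECT hypot(3, 4), log(10, 100)").show()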
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
index e8a3e129b4..4bd918ed01 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
@@ -438,6 +438,8 @@ abstract class InterpretedHashFunction {
* We should use this hash function for both shuffle and bucket, so that we can guarantee shuffle
* and bucketing have same data distribution.
*/
+@ExpressionDescription(
+ usage = "_FUNC_(a1, a2, ...) - Returns a hash value of the arguments.")
case class Murmur3Hash(children: Seq[Expression], seed: Int) extends HashExpression[Int] {
def this(arguments: Seq[Expression]) = this(arguments, 42)
@@ -467,8 +469,8 @@ object Murmur3HashFunction extends InterpretedHashFunction {
}
/**
- * Print the result of an expression to stderr (used for debugging codegen).
- */
+ * Print the result of an expression to stderr (used for debugging codegen).
+ */
case class PrintToStderr(child: Expression) extends UnaryExpression {
override def dataType: DataType = child.dataType
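Murmur3Hash is the expression behind the SQL hash() function (assuming it is registered under that name, as in Spark 2.x releases); a hedged sketch with `spark` assumed:

    // Identical arguments always produce the same Int, which is what lets shuffle and
    // bucketing agree on data distribution.
    spark.sql("SELECT hash('Spark'), hash('Spark', 1, 2.0)").show()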
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
index a5b5758167..78310fb2f1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
@@ -97,7 +97,7 @@ trait NamedExpression extends Expression {
}
}
-abstract class Attribute extends LeafExpression with NamedExpression {
+abstract class Attribute extends LeafExpression with NamedExpression with NullIntolerant {
override def references: AttributeSet = AttributeSet(this)
@@ -329,10 +329,12 @@ case class PrettyAttribute(
override def withName(newName: String): Attribute = throw new UnsupportedOperationException
override def qualifier: Option[String] = throw new UnsupportedOperationException
override def exprId: ExprId = throw new UnsupportedOperationException
- override def nullable: Boolean = throw new UnsupportedOperationException
+ override def nullable: Boolean = true
}
object VirtualColumn {
- val groupingIdName: String = "grouping__id"
+ // The attribute name used by Hive, which produces a different result from Spark's; deprecated.
+ val hiveGroupingIdName: String = "grouping__id"
+ val groupingIdName: String = "spark_grouping_id"
val groupingIdAttribute: UnresolvedAttribute = UnresolvedAttribute(groupingIdName)
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala
index e22026d584..6a45249943 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala
@@ -34,6 +34,9 @@ import org.apache.spark.sql.types._
* coalesce(null, null, null) => null
* }}}
*/
+@ExpressionDescription(
+ usage = "_FUNC_(a1, a2, ...) - Returns the first non-null argument if exists. Otherwise, NULL.",
+ extended = "> SELECT _FUNC_(NULL, 1, NULL);\n 1")
case class Coalesce(children: Seq[Expression]) extends Expression {
/** Coalesce is nullable if all of its children are nullable, or if it has no children. */
@@ -89,6 +92,8 @@ case class Coalesce(children: Seq[Expression]) extends Expression {
/**
* Evaluates to `true` iff it's NaN.
*/
+@ExpressionDescription(
+ usage = "_FUNC_(a) - Returns true if a is NaN and false otherwise.")
case class IsNaN(child: Expression) extends UnaryExpression
with Predicate with ImplicitCastInputTypes {
@@ -126,6 +131,8 @@ case class IsNaN(child: Expression) extends UnaryExpression
* An Expression evaluates to `left` iff it's not NaN, or evaluates to `right` otherwise.
* This Expression is useful for mapping NaN values to null.
*/
+@ExpressionDescription(
+ usage = "_FUNC_(a,b) - Returns a iff it's not NaN, or b otherwise.")
case class NaNvl(left: Expression, right: Expression)
extends BinaryExpression with ImplicitCastInputTypes {
@@ -180,6 +187,8 @@ case class NaNvl(left: Expression, right: Expression)
/**
* An expression that is evaluated to true if the input is null.
*/
+@ExpressionDescription(
+ usage = "_FUNC_(a) - Returns true if a is NULL and false otherwise.")
case class IsNull(child: Expression) extends UnaryExpression with Predicate {
override def nullable: Boolean = false
@@ -201,6 +210,8 @@ case class IsNull(child: Expression) extends UnaryExpression with Predicate {
/**
* An expression that is evaluated to true if the input is not null.
*/
+@ExpressionDescription(
+ usage = "_FUNC_(a) - Returns true if a is not NULL and false otherwise.")
case class IsNotNull(child: Expression) extends UnaryExpression with Predicate {
override def nullable: Boolean = false
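A hedged sketch combining the documented null/NaN helpers (SparkSession `spark` assumed):

    // coalesce picks the first non-null argument; nanvl maps NaN to a fallback value.
    spark.sql(
      "SELECT coalesce(NULL, 1, NULL), isnan(cast('NaN' AS DOUBLE)), nanvl(cast('NaN' AS DOUBLE), 0.0)"
    ).show()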
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
index 07b67a0240..26b1ff39b3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
@@ -17,6 +17,8 @@
package org.apache.spark.sql.catalyst.expressions
+import java.lang.reflect.Modifier
+
import scala.annotation.tailrec
import scala.language.existentials
import scala.reflect.ClassTag
@@ -112,23 +114,23 @@ case class Invoke(
arguments: Seq[Expression] = Nil) extends Expression with NonSQLExpression {
override def nullable: Boolean = true
- override def children: Seq[Expression] = arguments.+:(targetObject)
+ override def children: Seq[Expression] = targetObject +: arguments
override def eval(input: InternalRow): Any =
throw new UnsupportedOperationException("Only code-generated evaluation is supported.")
- lazy val method = targetObject.dataType match {
+ @transient lazy val method = targetObject.dataType match {
case ObjectType(cls) =>
- cls
- .getMethods
- .find(_.getName == functionName)
- .getOrElse(sys.error(s"Couldn't find $functionName on $cls"))
- .getReturnType
- .getName
- case _ => ""
+ val m = cls.getMethods.find(_.getName == functionName)
+ if (m.isEmpty) {
+ sys.error(s"Couldn't find $functionName on $cls")
+ } else {
+ m
+ }
+ case _ => None
}
- lazy val unboxer = (dataType, method) match {
+ lazy val unboxer = (dataType, method.map(_.getReturnType.getName).getOrElse("")) match {
case (IntegerType, "java.lang.Object") => (s: String) =>
s"((java.lang.Integer)$s).intValue()"
case (LongType, "java.lang.Object") => (s: String) =>
@@ -155,21 +157,31 @@ case class Invoke(
// If the function can return null, we do an extra check to make sure our null bit is still set
// correctly.
val objNullCheck = if (ctx.defaultValue(dataType) == "null") {
- s"${ev.isNull} = ${ev.value} == null;"
+ s"boolean ${ev.isNull} = ${ev.value} == null;"
} else {
+ ev.isNull = obj.isNull
""
}
val value = unboxer(s"${obj.value}.$functionName($argString)")
+ val evaluate = if (method.forall(_.getExceptionTypes.isEmpty)) {
+ s"$javaType ${ev.value} = ${obj.isNull} ? ${ctx.defaultValue(dataType)} : ($javaType) $value;"
+ } else {
+ s"""
+ $javaType ${ev.value} = ${ctx.defaultValue(javaType)};
+ try {
+ ${ev.value} = ${obj.isNull} ? ${ctx.defaultValue(javaType)} : ($javaType) $value;
+ } catch (Exception e) {
+ org.apache.spark.unsafe.Platform.throwException(e);
+ }
+ """
+ }
+
s"""
${obj.code}
${argGen.map(_.code).mkString("\n")}
-
- boolean ${ev.isNull} = ${obj.isNull};
- $javaType ${ev.value} =
- ${ev.isNull} ?
- ${ctx.defaultValue(dataType)} : ($javaType) $value;
+ $evaluate
$objNullCheck
"""
}
@@ -214,6 +226,16 @@ case class NewInstance(
override def children: Seq[Expression] = arguments
+ override lazy val resolved: Boolean = {
+ // If the class to construct is an inner class, we need to get its outer pointer, or this
+ // expression should be regarded as unresolved.
+ // Note that static inner classes (e.g., inner classes within Scala objects) don't need
+ // outer pointer registration.
+ val needOuterPointer =
+ outerPointer.isEmpty && cls.isMemberClass && !Modifier.isStatic(cls.getModifiers)
+ childrenResolved && !needOuterPointer
+ }
+
override def eval(input: InternalRow): Any =
throw new UnsupportedOperationException("Only code-generated evaluation is supported.")
@@ -424,6 +446,8 @@ case class MapObjects private(
override def genCode(ctx: CodegenContext, ev: ExprCode): String = {
val javaType = ctx.javaType(dataType)
val elementJavaType = ctx.javaType(loopVar.dataType)
+ ctx.addMutableState("boolean", loopVar.isNull, "")
+ ctx.addMutableState(elementJavaType, loopVar.value, "")
val genInputData = inputData.gen(ctx)
val genFunction = lambdaFunction.gen(ctx)
val dataLength = ctx.freshName("dataLength")
@@ -444,9 +468,9 @@ case class MapObjects private(
}
val loopNullCheck = if (primitiveElement) {
- s"boolean ${loopVar.isNull} = ${genInputData.value}.isNullAt($loopIndex);"
+ s"${loopVar.isNull} = ${genInputData.value}.isNullAt($loopIndex);"
} else {
- s"boolean ${loopVar.isNull} = ${genInputData.isNull} || ${loopVar.value} == null;"
+ s"${loopVar.isNull} = ${genInputData.isNull} || ${loopVar.value} == null;"
}
s"""
@@ -462,7 +486,7 @@ case class MapObjects private(
int $loopIndex = 0;
while ($loopIndex < $dataLength) {
- $elementJavaType ${loopVar.value} =
+ ${loopVar.value} =
($elementJavaType)${genInputData.value}${itemAccessor(loopIndex)};
$loopNullCheck
@@ -502,22 +526,26 @@ case class CreateExternalRow(children: Seq[Expression], schema: StructType)
override def genCode(ctx: CodegenContext, ev: ExprCode): String = {
val rowClass = classOf[GenericRowWithSchema].getName
val values = ctx.freshName("values")
- val schemaField = ctx.addReferenceObj("schema", schema)
- s"""
- boolean ${ev.isNull} = false;
- final Object[] $values = new Object[${children.size}];
- """ +
- children.zipWithIndex.map { case (e, i) =>
- val eval = e.gen(ctx)
- eval.code + s"""
+ ctx.addMutableState("Object[]", values, "")
+
+ val childrenCodes = children.zipWithIndex.map { case (e, i) =>
+ val eval = e.gen(ctx)
+ eval.code + s"""
if (${eval.isNull}) {
$values[$i] = null;
} else {
$values[$i] = ${eval.value};
}
"""
- }.mkString("\n") +
- s"final ${classOf[Row].getName} ${ev.value} = new $rowClass($values, this.$schemaField);"
+ }
+ val childrenCode = ctx.splitExpressions(ctx.INPUT_ROW, childrenCodes)
+ val schemaField = ctx.addReferenceObj("schema", schema)
+ s"""
+ boolean ${ev.isNull} = false;
+ $values = new Object[${children.size}];
+ $childrenCode
+ final ${classOf[Row].getName} ${ev.value} = new $rowClass($values, this.$schemaField);
+ """
}
}
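The Invoke change above keeps the reflected java.lang.reflect.Method as an Option (instead of only its return-type name) so the codegen can test for declared exceptions and emit a try/catch only when needed. A hedged, self-contained sketch of that idiom (names here are illustrative, not the Catalyst API):

    import java.lang.reflect.Method

    def findMethod(cls: Class[_], functionName: String): Option[Method] =
      cls.getMethods.find(_.getName == functionName)

    // Both overloads of URL.openConnection declare IOException, so needsTryCatch is true
    // and the generated code would take the try/catch branch.
    val m: Option[Method] = findMethod(classOf[java.net.URL], "openConnection")
    val needsTryCatch: Boolean = m.exists(_.getExceptionTypes.nonEmpty)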
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala
index f1fa13daa7..23baa6f783 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala
@@ -92,4 +92,11 @@ package object expressions {
StructType(attrs.map(a => StructField(a.name, a.dataType, a.nullable)))
}
}
+
+ /**
+ * An expression that inherits this trait is null-intolerant, i.e. any null input will result in
+ * a null output. We use this information when constructing IsNotNull constraints.
+ */
+ trait NullIntolerant
}
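A minimal sketch of how the new NullIntolerant marker can drive IsNotNull constraint inference (the helper name is illustrative only, assuming the Catalyst types above are in scope; Attribute itself mixes in NullIntolerant per the namedExpressions.scala hunk, so it is matched first):

    def scanNullIntolerant(e: Expression): Seq[Attribute] = e match {
      case a: Attribute      => Seq(a)
      case _: NullIntolerant => e.children.flatMap(scanNullIntolerant)
      case _                 => Seq.empty
    }
    // A filter over EqualTo(a, b) can then be strengthened with IsNotNull(a) and IsNotNull(b).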
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
index 20818bfb1a..38f1210a4e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
@@ -88,9 +88,10 @@ trait PredicateHelper {
expr.references.subsetOf(plan.outputSet)
}
-
+@ExpressionDescription(
+ usage = "_FUNC_ a - Logical not")
case class Not(child: Expression)
- extends UnaryExpression with Predicate with ImplicitCastInputTypes {
+ extends UnaryExpression with Predicate with ImplicitCastInputTypes with NullIntolerant {
override def toString: String = s"NOT $child"
@@ -109,6 +110,8 @@ case class Not(child: Expression)
/**
* Evaluates to `true` if `list` contains `value`.
*/
+@ExpressionDescription(
+ usage = "expr _FUNC_(val1, val2, ...) - Returns true if expr equals to any valN.")
case class In(value: Expression, list: Seq[Expression]) extends Predicate
with ImplicitCastInputTypes {
@@ -243,6 +246,8 @@ case class InSet(child: Expression, hset: Set[Any]) extends UnaryExpression with
}
}
+@ExpressionDescription(
+ usage = "a _FUNC_ b - Logical AND.")
case class And(left: Expression, right: Expression) extends BinaryOperator with Predicate {
override def inputType: AbstractDataType = BooleanType
@@ -274,26 +279,40 @@ case class And(left: Expression, right: Expression) extends BinaryOperator with
val eval2 = right.gen(ctx)
// The result should be `false`, if any of them is `false` whenever the other is null or not.
- s"""
- ${eval1.code}
- boolean ${ev.isNull} = false;
- boolean ${ev.value} = false;
+ if (!left.nullable && !right.nullable) {
+ ev.isNull = "false"
+ s"""
+ ${eval1.code}
+ boolean ${ev.value} = false;
- if (!${eval1.isNull} && !${eval1.value}) {
- } else {
- ${eval2.code}
- if (!${eval2.isNull} && !${eval2.value}) {
- } else if (!${eval1.isNull} && !${eval2.isNull}) {
- ${ev.value} = true;
+ if (${eval1.value}) {
+ ${eval2.code}
+ ${ev.value} = ${eval2.value};
+ }
+ """
+ } else {
+ s"""
+ ${eval1.code}
+ boolean ${ev.isNull} = false;
+ boolean ${ev.value} = false;
+
+ if (!${eval1.isNull} && !${eval1.value}) {
} else {
- ${ev.isNull} = true;
+ ${eval2.code}
+ if (!${eval2.isNull} && !${eval2.value}) {
+ } else if (!${eval1.isNull} && !${eval2.isNull}) {
+ ${ev.value} = true;
+ } else {
+ ${ev.isNull} = true;
+ }
}
- }
- """
+ """
+ }
}
}
-
+@ExpressionDescription(
+ usage = "a _FUNC_ b - Logical OR.")
case class Or(left: Expression, right: Expression) extends BinaryOperator with Predicate {
override def inputType: AbstractDataType = BooleanType
@@ -325,22 +344,35 @@ case class Or(left: Expression, right: Expression) extends BinaryOperator with P
val eval2 = right.gen(ctx)
// The result should be `true`, if any of them is `true` whenever the other is null or not.
- s"""
- ${eval1.code}
- boolean ${ev.isNull} = false;
- boolean ${ev.value} = true;
+ if (!left.nullable && !right.nullable) {
+ ev.isNull = "false"
+ s"""
+ ${eval1.code}
+ boolean ${ev.value} = true;
- if (!${eval1.isNull} && ${eval1.value}) {
- } else {
- ${eval2.code}
- if (!${eval2.isNull} && ${eval2.value}) {
- } else if (!${eval1.isNull} && !${eval2.isNull}) {
- ${ev.value} = false;
+ if (!${eval1.value}) {
+ ${eval2.code}
+ ${ev.value} = ${eval2.value};
+ }
+ """
+ } else {
+ s"""
+ ${eval1.code}
+ boolean ${ev.isNull} = false;
+ boolean ${ev.value} = true;
+
+ if (!${eval1.isNull} && ${eval1.value}) {
} else {
- ${ev.isNull} = true;
+ ${eval2.code}
+ if (!${eval2.isNull} && ${eval2.value}) {
+ } else if (!${eval1.isNull} && !${eval2.isNull}) {
+ ${ev.value} = false;
+ } else {
+ ${ev.isNull} = true;
+ }
}
- }
- """
+ """
+ }
}
}
@@ -375,8 +407,10 @@ private[sql] object Equality {
}
}
-
-case class EqualTo(left: Expression, right: Expression) extends BinaryComparison {
+@ExpressionDescription(
+ usage = "a _FUNC_ b - Returns TRUE if a equals b and false otherwise.")
+case class EqualTo(left: Expression, right: Expression)
+ extends BinaryComparison with NullIntolerant {
override def inputType: AbstractDataType = AnyDataType
@@ -399,7 +433,9 @@ case class EqualTo(left: Expression, right: Expression) extends BinaryComparison
}
}
-
+@ExpressionDescription(
+ usage = """a _FUNC_ b - Returns same result with EQUAL(=) operator for non-null operands,
+ but returns TRUE if both are NULL, FALSE if one of the them is NULL.""")
case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComparison {
override def inputType: AbstractDataType = AnyDataType
@@ -440,8 +476,10 @@ case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComp
}
}
-
-case class LessThan(left: Expression, right: Expression) extends BinaryComparison {
+@ExpressionDescription(
+ usage = "a _FUNC_ b - Returns TRUE if a is less than b.")
+case class LessThan(left: Expression, right: Expression)
+ extends BinaryComparison with NullIntolerant {
override def inputType: AbstractDataType = TypeCollection.Ordered
@@ -452,8 +490,10 @@ case class LessThan(left: Expression, right: Expression) extends BinaryCompariso
protected override def nullSafeEval(input1: Any, input2: Any): Any = ordering.lt(input1, input2)
}
-
-case class LessThanOrEqual(left: Expression, right: Expression) extends BinaryComparison {
+@ExpressionDescription(
+ usage = "a _FUNC_ b - Returns TRUE if a is not greater than b.")
+case class LessThanOrEqual(left: Expression, right: Expression)
+ extends BinaryComparison with NullIntolerant {
override def inputType: AbstractDataType = TypeCollection.Ordered
@@ -464,8 +504,10 @@ case class LessThanOrEqual(left: Expression, right: Expression) extends BinaryCo
protected override def nullSafeEval(input1: Any, input2: Any): Any = ordering.lteq(input1, input2)
}
-
-case class GreaterThan(left: Expression, right: Expression) extends BinaryComparison {
+@ExpressionDescription(
+ usage = "a _FUNC_ b - Returns TRUE if a is greater than b.")
+case class GreaterThan(left: Expression, right: Expression)
+ extends BinaryComparison with NullIntolerant {
override def inputType: AbstractDataType = TypeCollection.Ordered
@@ -476,8 +518,10 @@ case class GreaterThan(left: Expression, right: Expression) extends BinaryCompar
protected override def nullSafeEval(input1: Any, input2: Any): Any = ordering.gt(input1, input2)
}
-
-case class GreaterThanOrEqual(left: Expression, right: Expression) extends BinaryComparison {
+@ExpressionDescription(
+ usage = "a _FUNC_ b - Returns TRUE if a is not smaller than b.")
+case class GreaterThanOrEqual(left: Expression, right: Expression)
+ extends BinaryComparison with NullIntolerant {
override def inputType: AbstractDataType = TypeCollection.Ordered
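The And/Or codegen rewrite above adds a fast path when neither side is nullable; the semantics it must preserve is SQL three-valued logic. A hedged Scala model (None stands for SQL NULL; names are illustrative):

    def and3(l: Option[Boolean], r: Option[Boolean]): Option[Boolean] = (l, r) match {
      case (Some(false), _) | (_, Some(false)) => Some(false) // false wins, even over NULL
      case (Some(true), Some(true))            => Some(true)
      case _                                   => None        // everything else involves NULL
    }
    def or3(l: Option[Boolean], r: Option[Boolean]): Option[Boolean] = (l, r) match {
      case (Some(true), _) | (_, Some(true)) => Some(true)    // true wins, even over NULL
      case (Some(false), Some(false))        => Some(false)
      case _                                 => None
    }
    // With non-nullable inputs the None cases are unreachable, which is exactly the
    // simplified branch the generated Java takes above.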
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala
index 6be3cbcae6..1ec092a5be 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala
@@ -55,6 +55,8 @@ abstract class RDG extends LeafExpression with Nondeterministic {
}
/** Generate a random column with i.i.d. uniformly distributed values in [0, 1). */
+@ExpressionDescription(
+ usage = "_FUNC_(a) - Returns a random column with i.i.d. uniformly distributed values in [0, 1).")
case class Rand(seed: Long) extends RDG {
override protected def evalInternal(input: InternalRow): Double = rng.nextDouble()
@@ -78,6 +80,8 @@ case class Rand(seed: Long) extends RDG {
}
/** Generate a random column with i.i.d. gaussian random distribution. */
+@ExpressionDescription(
+ usage = "_FUNC_(a) - Returns a random column with i.i.d. gaussian random distribution.")
case class Randn(seed: Long) extends RDG {
override protected def evalInternal(input: InternalRow): Double = rng.nextGaussian()
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala
index b68009331b..85a5429263 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala
@@ -67,6 +67,8 @@ trait StringRegexExpression extends ImplicitCastInputTypes {
/**
* Simple RegEx pattern matching function
*/
+@ExpressionDescription(
+ usage = "str _FUNC_ pattern - Returns true if str matches pattern and false otherwise.")
case class Like(left: Expression, right: Expression)
extends BinaryExpression with StringRegexExpression {
@@ -117,7 +119,8 @@ case class Like(left: Expression, right: Expression)
}
}
-
+@ExpressionDescription(
+ usage = "str _FUNC_ regexp - Returns true if str matches regexp and false otherwise.")
case class RLike(left: Expression, right: Expression)
extends BinaryExpression with StringRegexExpression {
@@ -169,6 +172,9 @@ case class RLike(left: Expression, right: Expression)
/**
* Splits str around pat (pattern is a regular expression).
*/
+@ExpressionDescription(
+ usage = "_FUNC_(str, regex) - Splits str around occurrences that match regex",
+ extended = "> SELECT _FUNC_('oneAtwoBthreeC', '[ABC]');\n ['one', 'two', 'three']")
case class StringSplit(str: Expression, pattern: Expression)
extends BinaryExpression with ImplicitCastInputTypes {
@@ -198,6 +204,9 @@ case class StringSplit(str: Expression, pattern: Expression)
*
* NOTE: this expression is not THREAD-SAFE, as it has some internal mutable status.
*/
+@ExpressionDescription(
+ usage = "_FUNC_(str, regexp, rep) - replace all substrings of str that match regexp with rep.",
+ extended = "> SELECT _FUNC_('100-200', '(\\d+)', 'num');\n 'num-num'")
case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expression)
extends TernaryExpression with ImplicitCastInputTypes {
@@ -289,6 +298,9 @@ case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expressio
*
* NOTE: this expression is not THREAD-SAFE, as it has some internal mutable status.
*/
+@ExpressionDescription(
+ usage = "_FUNC_(str, regexp[, idx]) - extracts a group that matches regexp.",
+ extended = "> SELECT _FUNC_('100-200', '(\\d+)-(\\d+)', 1);\n '100'")
case class RegExpExtract(subject: Expression, regexp: Expression, idx: Expression)
extends TernaryExpression with ImplicitCastInputTypes {
def this(s: Expression, r: Expression) = this(s, r, Literal(1))
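A hedged sketch of the regexp helpers documented above (SparkSession `spark` assumed; [0-9] is used instead of \d to sidestep SQL string-literal escaping):

    // split around a regex; regexp_extract pulls out a capture group; regexp_replace rewrites matches.
    spark.sql("SELECT split('oneAtwoBthreeC', '[ABC]')").show(false)
    spark.sql(
      "SELECT regexp_extract('100-200', '([0-9]+)-([0-9]+)', 1), regexp_replace('100-200', '[0-9]+', 'num')"
    ).show()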
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala
index be6b2530ef..93a8278528 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala
@@ -164,7 +164,7 @@ trait BaseGenericInternalRow extends InternalRow {
abstract class MutableRow extends InternalRow {
def setNullAt(i: Int): Unit
- def update(i: Int, value: Any)
+ def update(i: Int, value: Any): Unit
// default implementation (slow)
def setBoolean(i: Int, value: Boolean): Unit = { update(i, value) }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
index 3ee19cc4ad..a17482697d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
@@ -35,6 +35,9 @@ import org.apache.spark.unsafe.types.{ByteArray, UTF8String}
* An expression that concatenates multiple input strings into a single string.
* If any input is null, concat returns null.
*/
+@ExpressionDescription(
+ usage = "_FUNC_(str1, str2, ..., strN) - Returns the concatenation of str1, str2, ..., strN",
+ extended = "> SELECT _FUNC_('Spark','SQL');\n 'SparkSQL'")
case class Concat(children: Seq[Expression]) extends Expression with ImplicitCastInputTypes {
override def inputTypes: Seq[AbstractDataType] = Seq.fill(children.size)(StringType)
@@ -70,6 +73,10 @@ case class Concat(children: Seq[Expression]) extends Expression with ImplicitCas
*
* Returns null if the separator is null. Otherwise, concat_ws skips all null values.
*/
+@ExpressionDescription(
+ usage =
+ "_FUNC_(sep, [str | array(str)]+) - Returns the concatenation of the strings separated by sep.",
+ extended = "> SELECT _FUNC_(' ', Spark', 'SQL');\n 'Spark SQL'")
case class ConcatWs(children: Seq[Expression])
extends Expression with ImplicitCastInputTypes {
@@ -188,7 +195,7 @@ case class Upper(child: Expression)
*/
@ExpressionDescription(
usage = "_FUNC_(str) - Returns str with all characters changed to lowercase",
- extended = "> SELECT _FUNC_('SparkSql');\n'sparksql'")
+ extended = "> SELECT _FUNC_('SparkSql');\n 'sparksql'")
case class Lower(child: Expression) extends UnaryExpression with String2StringExpression {
override def convert(v: UTF8String): UTF8String = v.toLowerCase
@@ -270,6 +277,11 @@ object StringTranslate {
* The translate will happen when any character in the string matching with the character
* in the `matchingExpr`.
*/
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+ usage = """_FUNC_(input, from, to) - Translates the input string by replacing the characters present in the from string with the corresponding characters in the to string""",
+ extended = "> SELECT _FUNC_('AaBbCc', 'abc', '123');\n 'A1B2C3'")
+// scalastyle:on line.size.limit
case class StringTranslate(srcExpr: Expression, matchingExpr: Expression, replaceExpr: Expression)
extends TernaryExpression with ImplicitCastInputTypes {
@@ -325,6 +337,12 @@ case class StringTranslate(srcExpr: Expression, matchingExpr: Expression, replac
* delimited list (right). Returns 0, if the string wasn't found or if the given
* string (left) contains a comma.
*/
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+ usage = """_FUNC_(str, str_array) - Returns the index (1-based) of the given string (left) in the comma-delimited list (right).
+ Returns 0, if the string wasn't found or if the given string (left) contains a comma.""",
+ extended = "> SELECT _FUNC_('ab','abc,b,ab,c,def');\n 3")
+// scalastyle:on
case class FindInSet(left: Expression, right: Expression) extends BinaryExpression
with ImplicitCastInputTypes {
@@ -347,6 +365,9 @@ case class FindInSet(left: Expression, right: Expression) extends BinaryExpressi
/**
* A function that trims the spaces from both ends of the specified string.
*/
+@ExpressionDescription(
+ usage = "_FUNC_(str) - Removes the leading and trailing space characters from str.",
+ extended = "> SELECT _FUNC_(' SparkSQL ');\n 'SparkSQL'")
case class StringTrim(child: Expression)
extends UnaryExpression with String2StringExpression {
@@ -362,6 +383,9 @@ case class StringTrim(child: Expression)
/**
* A function that trims the spaces from the left end of the given string.
*/
+@ExpressionDescription(
+ usage = "_FUNC_(str) - Removes the leading space characters from str.",
+ extended = "> SELECT _FUNC_(' SparkSQL ');\n 'SparkSQL '")
case class StringTrimLeft(child: Expression)
extends UnaryExpression with String2StringExpression {
@@ -377,6 +401,9 @@ case class StringTrimLeft(child: Expression)
/**
* A function that trims the spaces from the right end of the given string.
*/
+@ExpressionDescription(
+ usage = "_FUNC_(str) - Removes the trailing space characters from str.",
+ extended = "> SELECT _FUNC_(' SparkSQL ');\n ' SparkSQL'")
case class StringTrimRight(child: Expression)
extends UnaryExpression with String2StringExpression {
@@ -396,6 +423,9 @@ case class StringTrimRight(child: Expression)
*
* NOTE: this is not a zero-based, but a 1-based index. The first character in str has index 1.
*/
+@ExpressionDescription(
+ usage = "_FUNC_(str, substr) - Returns the (1-based) index of the first occurrence of substr in str.",
+ extended = "> SELECT _FUNC_('SparkSQL', 'SQL');\n 6")
case class StringInstr(str: Expression, substr: Expression)
extends BinaryExpression with ImplicitCastInputTypes {
@@ -422,6 +452,15 @@ case class StringInstr(str: Expression, substr: Expression)
* returned. If count is negative, everything to the right of the final delimiter (counting from the
* right) is returned. substring_index performs a case-sensitive match when searching for delim.
*/
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+ usage = """_FUNC_(str, delim, count) - Returns the substring from str before count occurrences of the delimiter delim.
+ If count is positive, everything to the left of the final delimiter (counting from the
+ left) is returned. If count is negative, everything to the right of the final delimiter
+ (counting from the right) is returned. Substring_index performs a case-sensitive match
+ when searching for delim.""",
+ extended = "> SELECT _FUNC_('www.apache.org', '.', 2);\n 'www.apache'")
+// scalastyle:on line.size.limit
case class SubstringIndex(strExpr: Expression, delimExpr: Expression, countExpr: Expression)
extends TernaryExpression with ImplicitCastInputTypes {
@@ -445,6 +484,12 @@ case class SubstringIndex(strExpr: Expression, delimExpr: Expression, countExpr:
* A function that returns the position of the first occurrence of substr
* in given string after position pos.
*/
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+ usage = """_FUNC_(substr, str[, pos]) - Returns the position of the first occurrence of substr in str after position pos.
+ The given pos and return value are 1-based.""",
+ extended = "> SELECT _FUNC_('bar', 'foobarbar', 5);\n 7")
+// scalastyle:on line.size.limit
case class StringLocate(substr: Expression, str: Expression, start: Expression)
extends TernaryExpression with ImplicitCastInputTypes {
@@ -510,6 +555,11 @@ case class StringLocate(substr: Expression, str: Expression, start: Expression)
/**
* Returns str, left-padded with pad to a length of len.
*/
+@ExpressionDescription(
+ usage = """_FUNC_(str, len, pad) - Returns str, left-padded with pad to a length of len.
+ If str is longer than len, the return value is shortened to len characters.""",
+ extended = "> SELECT _FUNC_('hi', 5, '??');\n '???hi'\n" +
+ "> SELECT _FUNC_('hi', 1, '??');\n 'h'")
case class StringLPad(str: Expression, len: Expression, pad: Expression)
extends TernaryExpression with ImplicitCastInputTypes {
@@ -531,6 +581,11 @@ case class StringLPad(str: Expression, len: Expression, pad: Expression)
/**
* Returns str, right-padded with pad to a length of len.
*/
+@ExpressionDescription(
+ usage = """_FUNC_(str, len, pad) - Returns str, right-padded with pad to a length of len.
+ If str is longer than len, the return value is shortened to len characters.""",
+ extended = "> SELECT _FUNC_('hi', 5, '??');\n 'hi???'\n" +
+ "> SELECT _FUNC_('hi', 1, '??');\n 'h'")
case class StringRPad(str: Expression, len: Expression, pad: Expression)
extends TernaryExpression with ImplicitCastInputTypes {
@@ -552,6 +607,11 @@ case class StringRPad(str: Expression, len: Expression, pad: Expression)
/**
* Returns the input formatted according do printf-style format strings
*/
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+ usage = "_FUNC_(String format, Obj... args) - Returns a formatted string from printf-style format strings.",
+ extended = "> SELECT _FUNC_(\"Hello World %d %s\", 100, \"days\");\n 'Hello World 100 days'")
+// scalastyle:on line.size.limit
case class FormatString(children: Expression*) extends Expression with ImplicitCastInputTypes {
require(children.nonEmpty, "format_string() should take at least 1 argument")
@@ -618,25 +678,33 @@ case class FormatString(children: Expression*) extends Expression with ImplicitC
}
/**
- * Returns string, with the first letter of each word in uppercase.
+ * Returns string, with the first letter of each word in uppercase, all other letters in lowercase.
* Words are delimited by whitespace.
*/
+@ExpressionDescription(
+ usage =
+ """_FUNC_(str) - Returns str with the first letter of each word in uppercase.
+ All other letters are in lowercase. Words are delimited by white space.""",
+ extended = "> SELECT initcap('sPark sql');\n 'Spark Sql'")
case class InitCap(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {
override def inputTypes: Seq[DataType] = Seq(StringType)
override def dataType: DataType = StringType
override def nullSafeEval(string: Any): Any = {
- string.asInstanceOf[UTF8String].toTitleCase
+ string.asInstanceOf[UTF8String].toLowerCase.toTitleCase
}
override def genCode(ctx: CodegenContext, ev: ExprCode): String = {
- defineCodeGen(ctx, ev, str => s"$str.toTitleCase()")
+ defineCodeGen(ctx, ev, str => s"$str.toLowerCase().toTitleCase()")
}
}
/**
* Returns the string which repeats the given string value n times.
*/
+@ExpressionDescription(
+ usage = "_FUNC_(str, n) - Returns the string which repeat the given string value n times.",
+ extended = "> SELECT _FUNC_('123', 2);\n '123123'")
case class StringRepeat(str: Expression, times: Expression)
extends BinaryExpression with ImplicitCastInputTypes {
@@ -659,6 +727,9 @@ case class StringRepeat(str: Expression, times: Expression)
/**
* Returns the reversed given string.
*/
+@ExpressionDescription(
+ usage = "_FUNC_(str) - Returns the reversed given string.",
+ extended = "> SELECT _FUNC_('Spark SQL');\n 'LQS krapS'")
case class StringReverse(child: Expression) extends UnaryExpression with String2StringExpression {
override def convert(v: UTF8String): UTF8String = v.reverse()
@@ -672,6 +743,9 @@ case class StringReverse(child: Expression) extends UnaryExpression with String2
/**
* Returns a string of n spaces.
*/
+@ExpressionDescription(
+ usage = "_FUNC_(n) - Returns a n spaces string.",
+ extended = "> SELECT _FUNC_(2);\n ' '")
case class StringSpace(child: Expression)
extends UnaryExpression with ImplicitCastInputTypes {
@@ -694,7 +768,14 @@ case class StringSpace(child: Expression)
/**
* A function that takes a substring of its first argument starting at a given position.
* Defined for String and Binary types.
+ *
+ * NOTE: this is not a zero-based, but a 1-based index. The first character in str has index 1.
*/
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+ usage = "_FUNC_(str, pos[, len]) - Returns the substring of str that starts at pos and is of length len or the slice of byte array that starts at pos and is of length len.",
+ extended = "> SELECT _FUNC_('Spark SQL', 5);\n 'k SQL'\n> SELECT _FUNC_('Spark SQL', -3);\n 'SQL'\n> SELECT _FUNC_('Spark SQL', 5, 1);\n 'k'")
+// scalastyle:on line.size.limit
case class Substring(str: Expression, pos: Expression, len: Expression)
extends TernaryExpression with ImplicitCastInputTypes {
@@ -732,6 +813,9 @@ case class Substring(str: Expression, pos: Expression, len: Expression)
/**
* A function that returns the length of the given string or binary expression.
*/
+@ExpressionDescription(
+ usage = "_FUNC_(str | binary) - Returns the length of str or number of bytes in binary data.",
+ extended = "> SELECT _FUNC_('Spark SQL');\n 9")
case class Length(child: Expression) extends UnaryExpression with ExpectsInputTypes {
override def dataType: DataType = IntegerType
override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(StringType, BinaryType))
@@ -752,6 +836,9 @@ case class Length(child: Expression) extends UnaryExpression with ExpectsInputTy
/**
* A function that returns the Levenshtein distance between the two given strings.
*/
+@ExpressionDescription(
+ usage = "_FUNC_(str1, str2) - Returns the Levenshtein distance between the two given strings.",
+ extended = "> SELECT _FUNC_('kitten', 'sitting');\n 3")
case class Levenshtein(left: Expression, right: Expression) extends BinaryExpression
with ImplicitCastInputTypes {
@@ -770,6 +857,9 @@ case class Levenshtein(left: Expression, right: Expression) extends BinaryExpres
/**
* A function that returns the soundex code of the given string expression.
*/
+@ExpressionDescription(
+ usage = "_FUNC_(str) - Returns soundex code of the string.",
+ extended = "> SELECT _FUNC_('Miller');\n 'M460'")
case class SoundEx(child: Expression) extends UnaryExpression with ExpectsInputTypes {
override def dataType: DataType = StringType
@@ -786,6 +876,10 @@ case class SoundEx(child: Expression) extends UnaryExpression with ExpectsInputT
/**
* Returns the numeric value of the first character of str.
*/
+@ExpressionDescription(
+ usage = "_FUNC_(str) - Returns the numeric value of the first character of str.",
+ extended = "> SELECT _FUNC_('222');\n 50\n" +
+ "> SELECT _FUNC_(2);\n 50")
case class Ascii(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {
override def dataType: DataType = IntegerType
@@ -817,6 +911,8 @@ case class Ascii(child: Expression) extends UnaryExpression with ImplicitCastInp
/**
* Converts the argument from binary to a base 64 string.
*/
+@ExpressionDescription(
+ usage = "_FUNC_(bin) - Convert the argument from binary to a base 64 string.")
case class Base64(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {
override def dataType: DataType = StringType
@@ -839,6 +935,8 @@ case class Base64(child: Expression) extends UnaryExpression with ImplicitCastIn
/**
* Converts the argument from a base 64 string to BINARY.
*/
+@ExpressionDescription(
+ usage = "_FUNC_(str) - Convert the argument from a base 64 string to binary.")
case class UnBase64(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {
override def dataType: DataType = BinaryType
@@ -860,6 +958,8 @@ case class UnBase64(child: Expression) extends UnaryExpression with ImplicitCast
* (one of 'US-ASCII', 'ISO-8859-1', 'UTF-8', 'UTF-16BE', 'UTF-16LE', 'UTF-16').
* If either argument is null, the result will also be null.
*/
+@ExpressionDescription(
+ usage = "_FUNC_(bin, str) - Decode the first argument using the second argument character set.")
case class Decode(bin: Expression, charset: Expression)
extends BinaryExpression with ImplicitCastInputTypes {
@@ -889,7 +989,9 @@ case class Decode(bin: Expression, charset: Expression)
* Encodes the first argument into a BINARY using the provided character set
* (one of 'US-ASCII', 'ISO-8859-1', 'UTF-8', 'UTF-16BE', 'UTF-16LE', 'UTF-16').
* If either argument is null, the result will also be null.
-*/
+ */
+@ExpressionDescription(
+ usage = "_FUNC_(str, str) - Encode the first argument using the second argument character set.")
case class Encode(value: Expression, charset: Expression)
extends BinaryExpression with ImplicitCastInputTypes {
@@ -919,6 +1021,11 @@ case class Encode(value: Expression, charset: Expression)
* and returns the result as a string. If D is 0, the result has no decimal point or
* fractional part.
*/
+@ExpressionDescription(
+ usage = """_FUNC_(X, D) - Formats the number X like '#,###,###.##', rounded to D decimal places.
+ If D is 0, the result has no decimal point or fractional part.
+ This is supposed to function like MySQL's FORMAT.""",
+ extended = "> SELECT _FUNC_(12332.123456, 4);\n '12,332.1235'")
case class FormatNumber(x: Expression, d: Expression)
extends BinaryExpression with ExpectsInputTypes {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala
index b8679474cf..c0b453dccf 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala
@@ -451,7 +451,11 @@ abstract class RowNumberLike extends AggregateWindowFunction {
* A [[SizeBasedWindowFunction]] needs the size of the current window for its calculation.
*/
trait SizeBasedWindowFunction extends AggregateWindowFunction {
- protected def n: AttributeReference = SizeBasedWindowFunction.n
+ // It's made a val so that the attribute created on the driver side is serialized to the executor side.
+ // Otherwise, if it were defined as a function, calling it on the executor side would actually
+ // return the singleton value instantiated on the executor side, which has a different expression ID
+ // from the one created on the driver side.
+ val n: AttributeReference = SizeBasedWindowFunction.n
}
object SizeBasedWindowFunction {
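A minimal Scala sketch (not part of the patch; the Shared/AsVal/AsDef names are hypothetical) of the serialization behaviour the comment above relies on: a val is stored as a field and travels with the serialized instance, whereas a def is re-evaluated after deserialization and therefore picks up the executor-side singleton, which has a different identity per JVM:

    import java.util.UUID

    object Shared { val id: String = UUID.randomUUID().toString }  // evaluated once per JVM

    case class AsDef() { def n: String = Shared.id }  // re-evaluated after deserialization: executor-side id
    case class AsVal() { val n: String = Shared.id }  // serialized as a field: driver-side id survives the trip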
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/identifiers.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/identifiers.scala
index 87f4d1b007..aae75956ea 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/identifiers.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/identifiers.scala
@@ -25,10 +25,10 @@ package org.apache.spark.sql.catalyst
* Format (quoted): "`name`" or "`db`.`name`"
*/
sealed trait IdentifierWithDatabase {
- val name: String
+ val identifier: String
def database: Option[String]
- def quotedString: String = database.map(db => s"`$db`.`$name`").getOrElse(s"`$name`")
- def unquotedString: String = database.map(db => s"$db.$name").getOrElse(name)
+ def quotedString: String = database.map(db => s"`$db`.`$identifier`").getOrElse(s"`$identifier`")
+ def unquotedString: String = database.map(db => s"$db.$identifier").getOrElse(identifier)
override def toString: String = quotedString
}
@@ -36,13 +36,15 @@ sealed trait IdentifierWithDatabase {
/**
* Identifies a table in a database.
* If `database` is not defined, the current database is used.
+ * When we register a permanent function in the FunctionRegistry, we use
+ * unquotedString as the function name.
*/
case class TableIdentifier(table: String, database: Option[String])
extends IdentifierWithDatabase {
- override val name: String = table
+ override val identifier: String = table
- def this(name: String) = this(name, None)
+ def this(table: String) = this(table, None)
}
@@ -58,9 +60,9 @@ object TableIdentifier {
case class FunctionIdentifier(funcName: String, database: Option[String])
extends IdentifierWithDatabase {
- override val name: String = funcName
+ override val identifier: String = funcName
- def this(name: String) = this(name, None)
+ def this(funcName: String) = this(funcName, None)
}
object FunctionIdentifier {
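A small usage sketch (illustrative only, derived from the definitions above) of the quoted and unquoted forms:

    TableIdentifier("tbl", Some("db")).quotedString     // `db`.`tbl`
    TableIdentifier("tbl", Some("db")).unquotedString   // db.tbl
    TableIdentifier("tbl", None).quotedString           // `tbl`
    FunctionIdentifier("myUdf", None).unquotedString    // myUdf, the name used when registering a permanent function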
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index a7a948ef1b..f5172b213a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -31,9 +31,9 @@ import org.apache.spark.sql.catalyst.rules._
import org.apache.spark.sql.types._
/**
- * Abstract class all optimizers should inherit of, contains the standard batches (extending
- * Optimizers can override this.
- */
+ * Abstract class that all optimizers should inherit from. It contains the standard batches
+ * (extending Optimizers can override this).
+ */
abstract class Optimizer extends RuleExecutor[LogicalPlan] {
def batches: Seq[Batch] = {
// Technically some of the rules in Finish Analysis are not optimizer rules and belong more
@@ -66,9 +66,7 @@ abstract class Optimizer extends RuleExecutor[LogicalPlan] {
ReorderJoin,
OuterJoinElimination,
PushPredicateThroughJoin,
- PushPredicateThroughProject,
- PushPredicateThroughGenerate,
- PushPredicateThroughAggregate,
+ PushDownPredicate,
LimitPushDown,
ColumnPruning,
InferFiltersFromConstraints,
@@ -86,6 +84,7 @@ abstract class Optimizer extends RuleExecutor[LogicalPlan] {
BooleanSimplification,
SimplifyConditionals,
RemoveDispensableExpressions,
+ BinaryComparisonSimplification,
PruneFilters,
EliminateSorts,
SimplifyCasts,
@@ -93,6 +92,8 @@ abstract class Optimizer extends RuleExecutor[LogicalPlan] {
EliminateSerialization) ::
Batch("Decimal Optimizations", FixedPoint(100),
DecimalAggregates) ::
+ Batch("Typed Filter Optimization", FixedPoint(100),
+ EmbedSerializerInFilter) ::
Batch("LocalRelation", FixedPoint(100),
ConvertToLocalRelation) ::
Batch("Subquery", Once,
@@ -111,11 +112,11 @@ abstract class Optimizer extends RuleExecutor[LogicalPlan] {
}
/**
- * Non-abstract representation of the standard Spark optimizing strategies
- *
- * To ensure extendability, we leave the standard rules in the abstract optimizer rules, while
- * specific rules go to the subclasses
- */
+ * Non-abstract representation of the standard Spark optimizing strategies.
+ *
+ * To ensure extensibility, we leave the standard rules in the abstract optimizer, while
+ * specific rules go into the subclasses.
+ */
object DefaultOptimizer extends Optimizer
/**
@@ -136,6 +137,7 @@ object SamplePushDown extends Rule[LogicalPlan] {
* representation of data item. For example back to back map operations.
*/
object EliminateSerialization extends Rule[LogicalPlan] {
+ // TODO: find a more general way to do this optimization.
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case m @ MapPartitions(_, deserializer, _, child: ObjectOperator)
if !deserializer.isInstanceOf[Attribute] &&
@@ -144,6 +146,20 @@ object EliminateSerialization extends Rule[LogicalPlan] {
m.copy(
deserializer = childWithoutSerialization.output.head,
child = childWithoutSerialization)
+
+ case m @ MapElements(_, deserializer, _, child: ObjectOperator)
+ if !deserializer.isInstanceOf[Attribute] &&
+ deserializer.dataType == child.outputObject.dataType =>
+ val childWithoutSerialization = child.withObjectOutput
+ m.copy(
+ deserializer = childWithoutSerialization.output.head,
+ child = childWithoutSerialization)
+
+ case d @ DeserializeToObject(_, s: SerializeFromObject)
+ if d.outputObjectType == s.inputObjectType =>
+ // Adds an extra Project here, to preserve the output expr id of `DeserializeToObject`.
+ val objAttr = Alias(s.child.output.head, "obj")(exprId = d.output.head.exprId)
+ Project(objAttr :: Nil, s.child)
}
}
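A schematic sketch (not from the patch; operator arguments abbreviated) of the kind of back-to-back object operators this rule targets, e.g. two typed map calls on a Dataset:

    // ds.map(f).map(g) initially plans roughly as
    //   MapElements(g, deserialize, serialize, MapElements(f, deserialize, serialize, child))
    // After the rule, the inner operator keeps its object-typed output, so the redundant
    // serialize/deserialize pair between the two maps disappears; the DeserializeToObject over
    // SerializeFromObject case is likewise collapsed into a Project that preserves the expression id.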
@@ -270,10 +286,10 @@ object SetOperationPushDown extends Rule[LogicalPlan] with PredicateHelper {
assert(children.nonEmpty)
if (projectList.forall(_.deterministic)) {
val newFirstChild = Project(projectList, children.head)
- val newOtherChildren = children.tail.map ( child => {
+ val newOtherChildren = children.tail.map { child =>
val rewrites = buildRewrites(children.head, child)
Project(projectList.map(pushToRight(_, rewrites)), child)
- } )
+ }
Union(newFirstChild +: newOtherChildren)
} else {
p
@@ -352,8 +368,8 @@ object ColumnPruning extends Rule[LogicalPlan] {
case p @ Project(_, g: Generate) if g.join && p.references.subsetOf(g.generatedSet) =>
p.copy(child = g.copy(join = false))
- // Eliminate unneeded attributes from right side of a LeftSemiJoin.
- case j @ Join(left, right, LeftSemi, condition) =>
+ // Eliminate unneeded attributes from right side of a Left Existence Join.
+ case j @ Join(left, right, LeftExistence(_), condition) =>
j.copy(right = prunedChild(right, j.references))
// all the columns will be used to compare, so we can't prune them
@@ -501,22 +517,28 @@ object LikeSimplification extends Rule[LogicalPlan] {
// Cases like "something\%" are not optimized, but this does not affect correctness.
private val startsWith = "([^_%]+)%".r
private val endsWith = "%([^_%]+)".r
+ private val startsAndEndsWith = "([^_%]+)%([^_%]+)".r
private val contains = "%([^_%]+)%".r
private val equalTo = "([^_%]*)".r
def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
- case Like(l, Literal(utf, StringType)) =>
- utf.toString match {
- case startsWith(pattern) if !pattern.endsWith("\\") =>
- StartsWith(l, Literal(pattern))
- case endsWith(pattern) =>
- EndsWith(l, Literal(pattern))
- case contains(pattern) if !pattern.endsWith("\\") =>
- Contains(l, Literal(pattern))
- case equalTo(pattern) =>
- EqualTo(l, Literal(pattern))
+ case Like(input, Literal(pattern, StringType)) =>
+ pattern.toString match {
+ case startsWith(prefix) if !prefix.endsWith("\\") =>
+ StartsWith(input, Literal(prefix))
+ case endsWith(postfix) =>
+ EndsWith(input, Literal(postfix))
+ // The 'a%a' pattern is basically the same as 'a%' && '%a'.
+ // However, the additional `Length` condition is required to prevent 'a' from matching 'a%a'.
+ case startsAndEndsWith(prefix, postfix) if !prefix.endsWith("\\") =>
+ And(GreaterThanOrEqual(Length(input), Literal(prefix.size + postfix.size)),
+ And(StartsWith(input, Literal(prefix)), EndsWith(input, Literal(postfix))))
+ case contains(infix) if !infix.endsWith("\\") =>
+ Contains(input, Literal(infix))
+ case equalTo(str) =>
+ EqualTo(input, Literal(str))
case _ =>
- Like(l, Literal.create(utf, StringType))
+ Like(input, Literal.create(pattern, StringType))
}
}
}
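Illustrative rewrites performed by this rule (a sketch; patterns written as SQL, results as the Catalyst predicates created above):

    // col LIKE 'abc%'   =>  StartsWith(col, "abc")
    // col LIKE '%abc'   =>  EndsWith(col, "abc")
    // col LIKE '%abc%'  =>  Contains(col, "abc")
    // col LIKE 'abc'    =>  col = 'abc'
    // col LIKE 'a%b'    =>  Length(col) >= 2 AND StartsWith(col, "a") AND EndsWith(col, "b")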
@@ -527,14 +549,14 @@ object LikeSimplification extends Rule[LogicalPlan] {
* Null value propagation from bottom to top of the expression tree.
*/
object NullPropagation extends Rule[LogicalPlan] {
- def nonNullLiteral(e: Expression): Boolean = e match {
+ private def nonNullLiteral(e: Expression): Boolean = e match {
case Literal(null, _) => false
case _ => true
}
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case q: LogicalPlan => q transformExpressionsUp {
- case e @ AggregateExpression(Count(exprs), _, _) if !exprs.exists(nonNullLiteral) =>
+ case e @ AggregateExpression(Count(exprs), _, _, _) if !exprs.exists(nonNullLiteral) =>
Cast(Literal(0L), e.dataType)
case e @ IsNull(c) if !c.nullable => Literal.create(false, BooleanType)
case e @ IsNotNull(c) if !c.nullable => Literal.create(true, BooleanType)
@@ -547,9 +569,9 @@ object NullPropagation extends Rule[LogicalPlan] {
Literal.create(null, e.dataType)
case e @ EqualNullSafe(Literal(null, _), r) => IsNull(r)
case e @ EqualNullSafe(l, Literal(null, _)) => IsNull(l)
- case e @ AggregateExpression(Count(exprs), mode, false) if !exprs.exists(_.nullable) =>
+ case ae @ AggregateExpression(Count(exprs), _, false, _) if !exprs.exists(_.nullable) =>
// This rule should be only triggered when isDistinct field is false.
- AggregateExpression(Count(Literal(1)), mode, isDistinct = false)
+ ae.copy(aggregateFunction = Count(Literal(1)))
// For Coalesce, remove null literals.
case e @ Coalesce(children) =>
@@ -770,20 +792,50 @@ object BooleanSimplification extends Rule[LogicalPlan] with PredicateHelper {
}
/**
+ * Simplifies binary comparisons with semantically-equal expressions:
+ * 1) Replace '<=>' with 'true' literal.
+ * 2) Replace '=', '<=', and '>=' with 'true' literal if both operands are non-nullable.
+ * 3) Replace '<' and '>' with 'false' literal if both operands are non-nullable.
+ */
+object BinaryComparisonSimplification extends Rule[LogicalPlan] with PredicateHelper {
+ def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+ case q: LogicalPlan => q transformExpressionsUp {
+ // True with equality
+ case a EqualNullSafe b if a.semanticEquals(b) => TrueLiteral
+ case a EqualTo b if !a.nullable && !b.nullable && a.semanticEquals(b) => TrueLiteral
+ case a GreaterThanOrEqual b if !a.nullable && !b.nullable && a.semanticEquals(b) =>
+ TrueLiteral
+ case a LessThanOrEqual b if !a.nullable && !b.nullable && a.semanticEquals(b) => TrueLiteral
+
+ // False with inequality
+ case a GreaterThan b if !a.nullable && !b.nullable && a.semanticEquals(b) => FalseLiteral
+ case a LessThan b if !a.nullable && !b.nullable && a.semanticEquals(b) => FalseLiteral
+ }
+ }
+}
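An illustrative sketch of the rewrites above, assuming a is a non-nullable attribute (plans abbreviated):

    // WHERE a <=> a  =>  WHERE true    (null-safe equality, regardless of nullability)
    // WHERE a =  a   =>  WHERE true    (only when both operands are non-nullable)
    // WHERE a >= a   =>  WHERE true
    // WHERE a <  a   =>  WHERE false
    // Nullable operands are left untouched for '=', '<', etc., since NULL = NULL yields NULL rather than true.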
+
+/**
* Simplifies conditional expressions (if / case).
*/
object SimplifyConditionals extends Rule[LogicalPlan] with PredicateHelper {
+ private def falseOrNullLiteral(e: Expression): Boolean = e match {
+ case FalseLiteral => true
+ case Literal(null, _) => true
+ case _ => false
+ }
+
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case q: LogicalPlan => q transformExpressionsUp {
case If(TrueLiteral, trueValue, _) => trueValue
case If(FalseLiteral, _, falseValue) => falseValue
+ case If(Literal(null, _), _, falseValue) => falseValue
- case e @ CaseWhen(branches, elseValue) if branches.exists(_._1 == FalseLiteral) =>
+ case e @ CaseWhen(branches, elseValue) if branches.exists(x => falseOrNullLiteral(x._1)) =>
// If there are branches that are always false, remove them.
// If there are no more branches left, just use the else value.
// Note that these two are handled together here in a single case statement because
// otherwise we cannot determine the data type for the elseValue if it is None (i.e. null).
- val newBranches = branches.filter(_._1 != FalseLiteral)
+ val newBranches = branches.filter(x => !falseOrNullLiteral(x._1))
if (newBranches.isEmpty) {
elseValue.getOrElse(Literal.create(null, e.dataType))
} else {
@@ -869,12 +921,13 @@ object PruneFilters extends Rule[LogicalPlan] with PredicateHelper {
}
/**
- * Pushes [[Filter]] operators through [[Project]] operators, in-lining any [[Alias Aliases]]
- * that were defined in the projection.
+ * Pushes [[Filter]] operators through many operators iff:
+ * 1) the operator is deterministic
+ * 2) the predicate is deterministic and the operator will not change any rows.
*
* This heuristic is valid assuming the expression evaluation cost is minimal.
*/
-object PushPredicateThroughProject extends Rule[LogicalPlan] with PredicateHelper {
+object PushDownPredicate extends Rule[LogicalPlan] with PredicateHelper {
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
// SPARK-13473: We can't push the predicate down when the underlying projection output non-
// deterministic field(s). Non-deterministic expressions are essentially stateful. This
@@ -891,41 +944,7 @@ object PushPredicateThroughProject extends Rule[LogicalPlan] with PredicateHelpe
})
project.copy(child = Filter(replaceAlias(condition, aliasMap), grandChild))
- }
-
-}
-
-/**
- * Push [[Filter]] operators through [[Generate]] operators. Parts of the predicate that reference
- * attributes generated in [[Generate]] will remain above, and the rest should be pushed beneath.
- */
-object PushPredicateThroughGenerate extends Rule[LogicalPlan] with PredicateHelper {
-
- def apply(plan: LogicalPlan): LogicalPlan = plan transform {
- case filter @ Filter(condition, g: Generate) =>
- // Predicates that reference attributes produced by the `Generate` operator cannot
- // be pushed below the operator.
- val (pushDown, stayUp) = splitConjunctivePredicates(condition).partition { cond =>
- cond.references.subsetOf(g.child.outputSet) && cond.deterministic
- }
- if (pushDown.nonEmpty) {
- val pushDownPredicate = pushDown.reduce(And)
- val newGenerate = Generate(g.generator, join = g.join, outer = g.outer,
- g.qualifier, g.generatorOutput, Filter(pushDownPredicate, g.child))
- if (stayUp.isEmpty) newGenerate else Filter(stayUp.reduce(And), newGenerate)
- } else {
- filter
- }
- }
-}
-
-/**
- * Push [[Filter]] operators through [[Aggregate]] operators, iff the filters reference only
- * non-aggregate attributes (typically literals or grouping expressions).
- */
-object PushPredicateThroughAggregate extends Rule[LogicalPlan] with PredicateHelper {
- def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case filter @ Filter(condition, aggregate: Aggregate) =>
// Find all the aliased expressions in the aggregate list that don't include any actual
// AggregateExpression, and create a map from the alias to the expression
@@ -951,25 +970,91 @@ object PushPredicateThroughAggregate extends Rule[LogicalPlan] with PredicateHel
} else {
filter
}
+
+ case filter @ Filter(condition, child)
+ if child.isInstanceOf[Union] || child.isInstanceOf[Intersect] =>
+ // Union/Intersect could change the rows, so non-deterministic predicates can't be pushed down
+ val (pushDown, stayUp) = splitConjunctivePredicates(condition).partition { cond =>
+ cond.deterministic
+ }
+ if (pushDown.nonEmpty) {
+ val pushDownCond = pushDown.reduceLeft(And)
+ val output = child.output
+ val newGrandChildren = child.children.map { grandchild =>
+ val newCond = pushDownCond transform {
+ case e if output.exists(_.semanticEquals(e)) =>
+ grandchild.output(output.indexWhere(_.semanticEquals(e)))
+ }
+ assert(newCond.references.subsetOf(grandchild.outputSet))
+ Filter(newCond, grandchild)
+ }
+ val newChild = child.withNewChildren(newGrandChildren)
+ if (stayUp.nonEmpty) {
+ Filter(stayUp.reduceLeft(And), newChild)
+ } else {
+ newChild
+ }
+ } else {
+ filter
+ }
+
+ case filter @ Filter(condition, e @ Except(left, _)) =>
+ pushDownPredicate(filter, e.left) { predicate =>
+ e.copy(left = Filter(predicate, left))
+ }
+
+ // two adjacent filters should be combined together by other rules
+ case filter @ Filter(_, f: Filter) => filter
+ // should not push predicates through Sample, or it will generate different results.
+ case filter @ Filter(_, s: Sample) => filter
+ // TODO: push predicates through expand
+ case filter @ Filter(_, e: Expand) => filter
+
+ case filter @ Filter(condition, u: UnaryNode) if u.expressions.forall(_.deterministic) =>
+ pushDownPredicate(filter, u.child) { predicate =>
+ u.withNewChildren(Seq(Filter(predicate, u.child)))
+ }
+ }
+
+ private def pushDownPredicate(
+ filter: Filter,
+ grandchild: LogicalPlan)(insertFilter: Expression => LogicalPlan): LogicalPlan = {
+ // Only push down the predicates that are deterministic and whose referenced attributes
+ // all come from the grandchild.
+ // TODO: non-deterministic predicates could be pushed through some operators that do not change
+ // the rows.
+ val (pushDown, stayUp) = splitConjunctivePredicates(filter.condition).partition { cond =>
+ cond.deterministic && cond.references.subsetOf(grandchild.outputSet)
+ }
+ if (pushDown.nonEmpty) {
+ val newChild = insertFilter(pushDown.reduceLeft(And))
+ if (stayUp.nonEmpty) {
+ Filter(stayUp.reduceLeft(And), newChild)
+ } else {
+ newChild
+ }
+ } else {
+ filter
+ }
}
}
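A sketch of the Union branch above (attribute rewriting elided): deterministic conjuncts are pushed into each child, while non-deterministic ones stay on top:

    // Filter(a > 1 && rand() < 0.5, Union(t1, t2))
    //   =>
    // Filter(rand() < 0.5, Union(Filter(a > 1, t1), Filter(a > 1, t2)))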
/**
- * Reorder the joins and push all the conditions into join, so that the bottom ones have at least
- * one condition.
- *
- * The order of joins will not be changed if all of them already have at least one condition.
- */
+ * Reorder the joins and push all the conditions into the joins, so that the bottom ones have at least
+ * one condition.
+ *
+ * The order of joins will not be changed if all of them already have at least one condition.
+ */
object ReorderJoin extends Rule[LogicalPlan] with PredicateHelper {
/**
- * Join a list of plans together and push down the conditions into them.
- *
- * The joined plan are picked from left to right, prefer those has at least one join condition.
- *
- * @param input a list of LogicalPlans to join.
- * @param conditions a list of condition for join.
- */
+ * Join a list of plans together and push down the conditions into them.
+ *
+ * The joined plans are picked from left to right, preferring those that have at least one join condition.
+ *
+ * @param input a list of LogicalPlans to join.
+ * @param conditions a list of conditions for the join.
+ */
@tailrec
def createOrderedJoin(input: Seq[LogicalPlan], conditions: Seq[Expression]): LogicalPlan = {
assert(input.size >= 2)
@@ -1110,7 +1195,7 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper {
(leftFilterConditions ++ commonFilterCondition).
reduceLeftOption(And).map(Filter(_, newJoin)).getOrElse(newJoin)
- case _ @ (LeftOuter | LeftSemi) =>
+ case LeftOuter | LeftExistence(_) =>
// push down the left side only `where` condition
val newLeft = leftFilterConditions.
reduceLeftOption(And).map(Filter(_, left)).getOrElse(left)
@@ -1131,7 +1216,7 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper {
split(joinCondition.map(splitConjunctivePredicates).getOrElse(Nil), left, right)
joinType match {
- case _ @ (Inner | LeftSemi) =>
+ case Inner | LeftExistence(_) =>
// push down the single side only join filter for both sides sub queries
val newLeft = leftJoinConditions.
reduceLeftOption(And).map(Filter(_, left)).getOrElse(left)
@@ -1225,13 +1310,13 @@ object DecimalAggregates extends Rule[LogicalPlan] {
private val MAX_DOUBLE_DIGITS = 15
def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
- case AggregateExpression(Sum(e @ DecimalType.Expression(prec, scale)), mode, isDistinct)
+ case ae @ AggregateExpression(Sum(e @ DecimalType.Expression(prec, scale)), _, _, _)
if prec + 10 <= MAX_LONG_DIGITS =>
- MakeDecimal(AggregateExpression(Sum(UnscaledValue(e)), mode, isDistinct), prec + 10, scale)
+ MakeDecimal(ae.copy(aggregateFunction = Sum(UnscaledValue(e))), prec + 10, scale)
- case AggregateExpression(Average(e @ DecimalType.Expression(prec, scale)), mode, isDistinct)
+ case ae @ AggregateExpression(Average(e @ DecimalType.Expression(prec, scale)), _, _, _)
if prec + 4 <= MAX_DOUBLE_DIGITS =>
- val newAggExpr = AggregateExpression(Average(UnscaledValue(e)), mode, isDistinct)
+ val newAggExpr = ae.copy(aggregateFunction = Average(UnscaledValue(e)))
Cast(
Divide(newAggExpr, Literal.create(math.pow(10.0, scale), DoubleType)),
DecimalType(prec + 4, scale + 4))
@@ -1313,3 +1398,30 @@ object ComputeCurrentTime extends Rule[LogicalPlan] {
}
}
}
+
+/**
+ * Typed [[Filter]] is by default surrounded by a [[DeserializeToObject]] beneath it and a
+ * [[SerializeFromObject]] above it. If these serializations can't be eliminated, we should embed
+ * the deserializer in the filter condition to save the extra serialization in the end.
+ */
+object EmbedSerializerInFilter extends Rule[LogicalPlan] {
+ def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+ case s @ SerializeFromObject(_, Filter(condition, d: DeserializeToObject)) =>
+ val numObjects = condition.collect {
+ case a: Attribute if a == d.output.head => a
+ }.length
+
+ if (numObjects > 1) {
+ // If the filter condition references the object more than once, we should not embed the
+ // deserializer in it, as the deserialization would then happen many times and slow down the
+ // execution.
+ // TODO: we can still embed it if we can make sure subexpression elimination works here.
+ s
+ } else {
+ val newCondition = condition transform {
+ case a: Attribute if a == d.output.head => d.deserializer.child
+ }
+ Filter(newCondition, d.child)
+ }
+ }
+}
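A sketch of the plan shape this rule targets, e.g. a typed Dataset filter such as ds.filter(p) (names abbreviated):

    // Before: SerializeFromObject(serializer,
    //           Filter(condition(obj), DeserializeToObject(deserializer, child)))
    // After (object referenced at most once in the condition):
    //   Filter(condition with obj replaced by the deserializer expression, child)
    // i.e. the serialize/deserialize pair is dropped and the deserializer is evaluated inside the predicate.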
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ASTNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ASTNode.scala
deleted file mode 100644
index 28f7b10ed6..0000000000
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ASTNode.scala
+++ /dev/null
@@ -1,99 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.spark.sql.catalyst.parser
-
-import org.antlr.runtime.{Token, TokenRewriteStream}
-
-import org.apache.spark.sql.catalyst.trees.{Origin, TreeNode}
-
-case class ASTNode(
- token: Token,
- startIndex: Int,
- stopIndex: Int,
- children: List[ASTNode],
- stream: TokenRewriteStream) extends TreeNode[ASTNode] {
- /** Cache the number of children. */
- val numChildren: Int = children.size
-
- /** tuple used in pattern matching. */
- val pattern: Some[(String, List[ASTNode])] = Some((token.getText, children))
-
- /** Line in which the ASTNode starts. */
- lazy val line: Int = {
- val line = token.getLine
- if (line == 0) {
- if (children.nonEmpty) children.head.line
- else 0
- } else {
- line
- }
- }
-
- /** Position of the Character at which ASTNode starts. */
- lazy val positionInLine: Int = {
- val line = token.getCharPositionInLine
- if (line == -1) {
- if (children.nonEmpty) children.head.positionInLine
- else 0
- } else {
- line
- }
- }
-
- /** Origin of the ASTNode. */
- override val origin: Origin = Origin(Some(line), Some(positionInLine))
-
- /** Source text. */
- lazy val source: String = stream.toOriginalString(startIndex, stopIndex)
-
- /** Get the source text that remains after this token. */
- lazy val remainder: String = {
- stream.fill()
- stream.toOriginalString(stopIndex + 1, stream.size() - 1).trim()
- }
-
- def text: String = token.getText
-
- def tokenType: Int = token.getType
-
- /**
- * Checks if this node is equal to another node.
- *
- * Right now this function only checks the name, type, text and children of the node
- * for equality.
- */
- def treeEquals(other: ASTNode): Boolean = {
- def check(f: ASTNode => Any): Boolean = {
- val l = f(this)
- val r = f(other)
- (l == null && r == null) || l.equals(r)
- }
- if (other == null) {
- false
- } else if (!check(_.token.getType)
- || !check(_.token.getText)
- || !check(_.numChildren)) {
- false
- } else {
- children.zip(other.children).forall {
- case (l, r) => l treeEquals r
- }
- }
- }
-
- override def simpleString: String = s"$text $line, $startIndex, $stopIndex, $positionInLine "
-}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AbstractSparkSQLParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AbstractSparkSQLParser.scala
deleted file mode 100644
index 7b456a6de3..0000000000
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AbstractSparkSQLParser.scala
+++ /dev/null
@@ -1,145 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.catalyst.parser
-
-import scala.language.implicitConversions
-import scala.util.parsing.combinator.lexical.StdLexical
-import scala.util.parsing.combinator.syntactical.StandardTokenParsers
-import scala.util.parsing.combinator.PackratParsers
-import scala.util.parsing.input.CharArrayReader.EofCh
-
-import org.apache.spark.sql.catalyst.plans.logical._
-
-private[sql] abstract class AbstractSparkSQLParser
- extends StandardTokenParsers with PackratParsers with ParserInterface {
-
- def parsePlan(input: String): LogicalPlan = synchronized {
- // Initialize the Keywords.
- initLexical
- phrase(start)(new lexical.Scanner(input)) match {
- case Success(plan, _) => plan
- case failureOrError => sys.error(failureOrError.toString)
- }
- }
- /* One time initialization of lexical.This avoid reinitialization of lexical in parse method */
- protected lazy val initLexical: Unit = lexical.initialize(reservedWords)
-
- protected case class Keyword(str: String) {
- def normalize: String = lexical.normalizeKeyword(str)
- def parser: Parser[String] = normalize
- }
-
- protected implicit def asParser(k: Keyword): Parser[String] = k.parser
-
- // By default, use Reflection to find the reserved words defined in the sub class.
- // NOTICE, Since the Keyword properties defined by sub class, we couldn't call this
- // method during the parent class instantiation, because the sub class instance
- // isn't created yet.
- protected lazy val reservedWords: Seq[String] =
- this
- .getClass
- .getMethods
- .filter(_.getReturnType == classOf[Keyword])
- .map(_.invoke(this).asInstanceOf[Keyword].normalize)
-
- // Set the keywords as empty by default, will change that later.
- override val lexical = new SqlLexical
-
- protected def start: Parser[LogicalPlan]
-
- // Returns the whole input string
- protected lazy val wholeInput: Parser[String] = new Parser[String] {
- def apply(in: Input): ParseResult[String] =
- Success(in.source.toString, in.drop(in.source.length()))
- }
-
- // Returns the rest of the input string that are not parsed yet
- protected lazy val restInput: Parser[String] = new Parser[String] {
- def apply(in: Input): ParseResult[String] =
- Success(
- in.source.subSequence(in.offset, in.source.length()).toString,
- in.drop(in.source.length()))
- }
-}
-
-class SqlLexical extends StdLexical {
- case class DecimalLit(chars: String) extends Token {
- override def toString: String = chars
- }
-
- /* This is a work around to support the lazy setting */
- def initialize(keywords: Seq[String]): Unit = {
- reserved.clear()
- reserved ++= keywords
- }
-
- /* Normal the keyword string */
- def normalizeKeyword(str: String): String = str.toLowerCase
-
- delimiters += (
- "@", "*", "+", "-", "<", "=", "<>", "!=", "<=", ">=", ">", "/", "(", ")",
- ",", ";", "%", "{", "}", ":", "[", "]", ".", "&", "|", "^", "~", "<=>"
- )
-
- protected override def processIdent(name: String) = {
- val token = normalizeKeyword(name)
- if (reserved contains token) Keyword(token) else Identifier(name)
- }
-
- override lazy val token: Parser[Token] =
- ( rep1(digit) ~ scientificNotation ^^ { case i ~ s => DecimalLit(i.mkString + s) }
- | '.' ~> (rep1(digit) ~ scientificNotation) ^^
- { case i ~ s => DecimalLit("0." + i.mkString + s) }
- | rep1(digit) ~ ('.' ~> digit.*) ~ scientificNotation ^^
- { case i1 ~ i2 ~ s => DecimalLit(i1.mkString + "." + i2.mkString + s) }
- | digit.* ~ identChar ~ (identChar | digit).* ^^
- { case first ~ middle ~ rest => processIdent((first ++ (middle :: rest)).mkString) }
- | rep1(digit) ~ ('.' ~> digit.*).? ^^ {
- case i ~ None => NumericLit(i.mkString)
- case i ~ Some(d) => DecimalLit(i.mkString + "." + d.mkString)
- }
- | '\'' ~> chrExcept('\'', '\n', EofCh).* <~ '\'' ^^
- { case chars => StringLit(chars mkString "") }
- | '"' ~> chrExcept('"', '\n', EofCh).* <~ '"' ^^
- { case chars => StringLit(chars mkString "") }
- | '`' ~> chrExcept('`', '\n', EofCh).* <~ '`' ^^
- { case chars => Identifier(chars mkString "") }
- | EofCh ^^^ EOF
- | '\'' ~> failure("unclosed string literal")
- | '"' ~> failure("unclosed string literal")
- | delim
- | failure("illegal character")
- )
-
- override def identChar: Parser[Elem] = letter | elem('_')
-
- private lazy val scientificNotation: Parser[String] =
- (elem('e') | elem('E')) ~> (elem('+') | elem('-')).? ~ rep1(digit) ^^ {
- case s ~ rest => "e" + s.mkString + rest.mkString
- }
-
- override def whitespace: Parser[Any] =
- ( whitespaceChar
- | '/' ~ '*' ~ comment
- | '/' ~ '/' ~ chrExcept(EofCh, '\n').*
- | '#' ~ chrExcept(EofCh, '\n').*
- | '-' ~ '-' ~ chrExcept(EofCh, '\n').*
- | '/' ~ '*' ~ failure("unclosed comment")
- ).*
-}
-
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
new file mode 100644
index 0000000000..aa59f3fb2a
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
@@ -0,0 +1,1455 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.catalyst.parser
+
+import java.sql.{Date, Timestamp}
+
+import scala.collection.JavaConverters._
+import scala.collection.mutable.ArrayBuffer
+
+import org.antlr.v4.runtime.{ParserRuleContext, Token}
+import org.antlr.v4.runtime.tree.{ParseTree, RuleNode, TerminalNode}
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.catalyst.{InternalRow, TableIdentifier}
+import org.apache.spark.sql.catalyst.analysis._
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.parser.SqlBaseParser._
+import org.apache.spark.sql.catalyst.plans._
+import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.CalendarInterval
+import org.apache.spark.util.random.RandomSampler
+
+/**
+ * The AstBuilder converts an ANTLR4 ParseTree into a catalyst Expression, LogicalPlan or
+ * TableIdentifier.
+ */
+class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging {
+ import ParserUtils._
+
+ protected def typedVisit[T](ctx: ParseTree): T = {
+ ctx.accept(this).asInstanceOf[T]
+ }
+
+ /**
+ * Override the default behavior for all visit methods. This will only return a non-null result
+ * when the context has only one child. This is done because there is no generic method to
+ * combine the results of the context children. In all other cases null is returned.
+ */
+ override def visitChildren(node: RuleNode): AnyRef = {
+ if (node.getChildCount == 1) {
+ node.getChild(0).accept(this)
+ } else {
+ null
+ }
+ }
+
+ override def visitSingleStatement(ctx: SingleStatementContext): LogicalPlan = withOrigin(ctx) {
+ visit(ctx.statement).asInstanceOf[LogicalPlan]
+ }
+
+ override def visitSingleExpression(ctx: SingleExpressionContext): Expression = withOrigin(ctx) {
+ visitNamedExpression(ctx.namedExpression)
+ }
+
+ override def visitSingleTableIdentifier(
+ ctx: SingleTableIdentifierContext): TableIdentifier = withOrigin(ctx) {
+ visitTableIdentifier(ctx.tableIdentifier)
+ }
+
+ override def visitSingleDataType(ctx: SingleDataTypeContext): DataType = withOrigin(ctx) {
+ visit(ctx.dataType).asInstanceOf[DataType]
+ }
+
+ /* ********************************************************************************************
+ * Plan parsing
+ * ******************************************************************************************** */
+ protected def plan(tree: ParserRuleContext): LogicalPlan = typedVisit(tree)
+
+ /**
+ * Make sure we do not try to create a plan for a native command.
+ */
+ override def visitExecuteNativeCommand(ctx: ExecuteNativeCommandContext): LogicalPlan = null
+
+ /**
+ * Create a plan for a SHOW FUNCTIONS command.
+ */
+ override def visitShowFunctions(ctx: ShowFunctionsContext): LogicalPlan = withOrigin(ctx) {
+ import ctx._
+ if (qualifiedName != null) {
+ val names = qualifiedName().identifier().asScala.map(_.getText).toList
+ names match {
+ case db :: name :: Nil =>
+ ShowFunctions(Some(db), Some(name))
+ case name :: Nil =>
+ ShowFunctions(None, Some(name))
+ case _ =>
+ throw new ParseException("SHOW FUNCTIONS unsupported name", ctx)
+ }
+ } else if (pattern != null) {
+ ShowFunctions(None, Some(string(pattern)))
+ } else {
+ ShowFunctions(None, None)
+ }
+ }
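Illustrative mappings produced by this visitor (a sketch; input shown as SQL):

    // SHOW FUNCTIONS               =>  ShowFunctions(None, None)
    // SHOW FUNCTIONS mydb.myFunc   =>  ShowFunctions(Some("mydb"), Some("myFunc"))
    // SHOW FUNCTIONS 'myFun*'      =>  ShowFunctions(None, Some("myFun*"))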
+
+ /**
+ * Create a plan for a DESCRIBE FUNCTION command.
+ */
+ override def visitDescribeFunction(ctx: DescribeFunctionContext): LogicalPlan = withOrigin(ctx) {
+ val functionName = ctx.qualifiedName().identifier().asScala.map(_.getText).mkString(".")
+ DescribeFunction(functionName, ctx.EXTENDED != null)
+ }
+
+ /**
+ * Create a top-level plan with Common Table Expressions.
+ */
+ override def visitQuery(ctx: QueryContext): LogicalPlan = withOrigin(ctx) {
+ val query = plan(ctx.queryNoWith)
+
+ // Apply CTEs
+ query.optional(ctx.ctes) {
+ val ctes = ctx.ctes.namedQuery.asScala.map {
+ case nCtx =>
+ val namedQuery = visitNamedQuery(nCtx)
+ (namedQuery.alias, namedQuery)
+ }
+
+ // Check for duplicate names.
+ ctes.groupBy(_._1).filter(_._2.size > 1).foreach {
+ case (name, _) =>
+ throw new ParseException(
+ s"Name '$name' is used for multiple common table expressions", ctx)
+ }
+
+ With(query, ctes.toMap)
+ }
+ }
+
+ /**
+ * Create a named logical plan.
+ *
+ * This is only used for Common Table Expressions.
+ */
+ override def visitNamedQuery(ctx: NamedQueryContext): SubqueryAlias = withOrigin(ctx) {
+ SubqueryAlias(ctx.name.getText, plan(ctx.queryNoWith))
+ }
+
+ /**
+ * Create a logical plan which allows for multiple inserts using one 'from' statement. These
+ * queries have the following SQL form:
+ * {{{
+ * [WITH cte...]?
+ * FROM src
+ * [INSERT INTO tbl1 SELECT *]+
+ * }}}
+ * For example:
+ * {{{
+ * FROM db.tbl1 A
+ * INSERT INTO dbo.tbl1 SELECT * WHERE A.value = 10 LIMIT 5
+ * INSERT INTO dbo.tbl2 SELECT * WHERE A.value = 12
+ * }}}
+ * This (Hive) feature cannot be combined with set-operators.
+ */
+ override def visitMultiInsertQuery(ctx: MultiInsertQueryContext): LogicalPlan = withOrigin(ctx) {
+ val from = visitFromClause(ctx.fromClause)
+
+ // Build the insert clauses.
+ val inserts = ctx.multiInsertQueryBody.asScala.map {
+ body =>
+ assert(body.querySpecification.fromClause == null,
+ "Multi-Insert queries cannot have a FROM clause in their individual SELECT statements",
+ body)
+
+ withQuerySpecification(body.querySpecification, from).
+ // Add organization statements.
+ optionalMap(body.queryOrganization)(withQueryResultClauses).
+ // Add insert.
+ optionalMap(body.insertInto())(withInsertInto)
+ }
+
+ // If there are multiple INSERTS just UNION them together into one query.
+ inserts match {
+ case Seq(query) => query
+ case queries => Union(queries)
+ }
+ }
+
+ /**
+ * Create a logical plan for a regular (single-insert) query.
+ */
+ override def visitSingleInsertQuery(
+ ctx: SingleInsertQueryContext): LogicalPlan = withOrigin(ctx) {
+ plan(ctx.queryTerm).
+ // Add organization statements.
+ optionalMap(ctx.queryOrganization)(withQueryResultClauses).
+ // Add insert.
+ optionalMap(ctx.insertInto())(withInsertInto)
+ }
+
+ /**
+ * Add an INSERT INTO [TABLE]/INSERT OVERWRITE TABLE operation to the logical plan.
+ */
+ private def withInsertInto(
+ ctx: InsertIntoContext,
+ query: LogicalPlan): LogicalPlan = withOrigin(ctx) {
+ val tableIdent = visitTableIdentifier(ctx.tableIdentifier)
+ val partitionKeys = Option(ctx.partitionSpec).map(visitPartitionSpec).getOrElse(Map.empty)
+
+ InsertIntoTable(
+ UnresolvedRelation(tableIdent, None),
+ partitionKeys,
+ query,
+ ctx.OVERWRITE != null,
+ ctx.EXISTS != null)
+ }
+
+ /**
+ * Create a partition specification map.
+ */
+ override def visitPartitionSpec(
+ ctx: PartitionSpecContext): Map[String, Option[String]] = withOrigin(ctx) {
+ ctx.partitionVal.asScala.map { pVal =>
+ val name = pVal.identifier.getText.toLowerCase
+ val value = Option(pVal.constant).map(visitStringConstant)
+ name -> value
+ }.toMap
+ }
+
+ /**
+ * Create a partition specification map without optional values.
+ */
+ protected def visitNonOptionalPartitionSpec(
+ ctx: PartitionSpecContext): Map[String, String] = withOrigin(ctx) {
+ visitPartitionSpec(ctx).mapValues(_.orNull).map(identity)
+ }
+
+ /**
+ * Convert a constant of any type into a string. This is typically used in DDL commands, and its
+ * main purpose is to prevent slight differences due to back-to-back conversions, i.e.:
+ * String -> Literal -> String.
+ */
+ protected def visitStringConstant(ctx: ConstantContext): String = withOrigin(ctx) {
+ ctx match {
+ case s: StringLiteralContext => createString(s)
+ case o => o.getText
+ }
+ }
+
+ /**
+ * Add ORDER BY/SORT BY/CLUSTER BY/DISTRIBUTE BY/LIMIT/WINDOWS clauses to the logical plan. These
+ * clauses determine the shape (ordering/partitioning/rows) of the query result.
+ */
+ private def withQueryResultClauses(
+ ctx: QueryOrganizationContext,
+ query: LogicalPlan): LogicalPlan = withOrigin(ctx) {
+ import ctx._
+
+ // Handle the ORDER BY, SORT BY, DISTRIBUTE BY, and CLUSTER BY clauses.
+ val withOrder = if (
+ !order.isEmpty && sort.isEmpty && distributeBy.isEmpty && clusterBy.isEmpty) {
+ // ORDER BY ...
+ Sort(order.asScala.map(visitSortItem), global = true, query)
+ } else if (order.isEmpty && !sort.isEmpty && distributeBy.isEmpty && clusterBy.isEmpty) {
+ // SORT BY ...
+ Sort(sort.asScala.map(visitSortItem), global = false, query)
+ } else if (order.isEmpty && sort.isEmpty && !distributeBy.isEmpty && clusterBy.isEmpty) {
+ // DISTRIBUTE BY ...
+ RepartitionByExpression(expressionList(distributeBy), query)
+ } else if (order.isEmpty && !sort.isEmpty && !distributeBy.isEmpty && clusterBy.isEmpty) {
+ // SORT BY ... DISTRIBUTE BY ...
+ Sort(
+ sort.asScala.map(visitSortItem),
+ global = false,
+ RepartitionByExpression(expressionList(distributeBy), query))
+ } else if (order.isEmpty && sort.isEmpty && distributeBy.isEmpty && !clusterBy.isEmpty) {
+ // CLUSTER BY ...
+ val expressions = expressionList(clusterBy)
+ Sort(
+ expressions.map(SortOrder(_, Ascending)),
+ global = false,
+ RepartitionByExpression(expressions, query))
+ } else if (order.isEmpty && sort.isEmpty && distributeBy.isEmpty && clusterBy.isEmpty) {
+ // [EMPTY]
+ query
+ } else {
+ throw new ParseException(
+ "Combination of ORDER BY/SORT BY/DISTRIBUTE BY/CLUSTER BY is not supported", ctx)
+ }
+
+ // WINDOWS
+ val withWindow = withOrder.optionalMap(windows)(withWindows)
+
+ // LIMIT
+ withWindow.optional(limit) {
+ Limit(typedVisit(limit), withWindow)
+ }
+ }
+
+ /**
+ * Create a logical plan using a query specification.
+ */
+ override def visitQuerySpecification(
+ ctx: QuerySpecificationContext): LogicalPlan = withOrigin(ctx) {
+ val from = OneRowRelation.optional(ctx.fromClause) {
+ visitFromClause(ctx.fromClause)
+ }
+ withQuerySpecification(ctx, from)
+ }
+
+ /**
+ * Add a query specification to a logical plan. The query specification is the core of the logical
+ * plan; this is where sourcing (FROM clause), transforming (SELECT TRANSFORM/MAP/REDUCE),
+ * projection (SELECT), aggregation (GROUP BY ... HAVING ...) and filtering (WHERE) take place.
+ *
+ * Note that query hints are ignored (both by the parser and the builder).
+ */
+ private def withQuerySpecification(
+ ctx: QuerySpecificationContext,
+ relation: LogicalPlan): LogicalPlan = withOrigin(ctx) {
+ import ctx._
+
+ // WHERE
+ def filter(ctx: BooleanExpressionContext, plan: LogicalPlan): LogicalPlan = {
+ Filter(expression(ctx), plan)
+ }
+
+ // Expressions.
+ val expressions = Option(namedExpressionSeq).toSeq
+ .flatMap(_.namedExpression.asScala)
+ .map(typedVisit[Expression])
+
+ // Create either a transform or a regular query.
+ val specType = Option(kind).map(_.getType).getOrElse(SqlBaseParser.SELECT)
+ specType match {
+ case SqlBaseParser.MAP | SqlBaseParser.REDUCE | SqlBaseParser.TRANSFORM =>
+ // Transform
+
+ // Add where.
+ val withFilter = relation.optionalMap(where)(filter)
+
+ // Create the attributes.
+ val (attributes, schemaLess) = if (colTypeList != null) {
+ // Typed return columns.
+ (createStructType(colTypeList).toAttributes, false)
+ } else if (identifierSeq != null) {
+ // Untyped return columns.
+ val attrs = visitIdentifierSeq(identifierSeq).map { name =>
+ AttributeReference(name, StringType, nullable = true)()
+ }
+ (attrs, false)
+ } else {
+ (Seq(AttributeReference("key", StringType)(),
+ AttributeReference("value", StringType)()), true)
+ }
+
+ // Create the transform.
+ ScriptTransformation(
+ expressions,
+ string(script),
+ attributes,
+ withFilter,
+ withScriptIOSchema(
+ ctx, inRowFormat, recordWriter, outRowFormat, recordReader, schemaLess))
+
+ case SqlBaseParser.SELECT =>
+ // Regular select
+
+ // Add lateral views.
+ val withLateralView = ctx.lateralView.asScala.foldLeft(relation)(withGenerate)
+
+ // Add where.
+ val withFilter = withLateralView.optionalMap(where)(filter)
+
+ // Add aggregation or a project.
+ val namedExpressions = expressions.map {
+ case e: NamedExpression => e
+ case e: Expression => UnresolvedAlias(e)
+ }
+ val withProject = if (aggregation != null) {
+ withAggregation(aggregation, namedExpressions, withFilter)
+ } else if (namedExpressions.nonEmpty) {
+ Project(namedExpressions, withFilter)
+ } else {
+ withFilter
+ }
+
+ // Having
+ val withHaving = withProject.optional(having) {
+ // Note that we added a cast to boolean. If the expression itself is already boolean,
+ // the optimizer will get rid of the unnecessary cast.
+ Filter(Cast(expression(having), BooleanType), withProject)
+ }
+
+ // Distinct
+ val withDistinct = if (setQuantifier() != null && setQuantifier().DISTINCT() != null) {
+ Distinct(withHaving)
+ } else {
+ withHaving
+ }
+
+ // Window
+ withDistinct.optionalMap(windows)(withWindows)
+ }
+ }
+
+ /**
+ * Create a (Hive based) [[ScriptInputOutputSchema]].
+ */
+ protected def withScriptIOSchema(
+ ctx: QuerySpecificationContext,
+ inRowFormat: RowFormatContext,
+ recordWriter: Token,
+ outRowFormat: RowFormatContext,
+ recordReader: Token,
+ schemaLess: Boolean): ScriptInputOutputSchema = {
+ throw new ParseException("Script Transform is not supported", ctx)
+ }
+
+ /**
+ * Create a logical plan for a given 'FROM' clause. Note that we support multiple (comma
+ * separated) relations here; these get converted into a single plan by a condition-less inner join.
+ */
+ override def visitFromClause(ctx: FromClauseContext): LogicalPlan = withOrigin(ctx) {
+ val from = ctx.relation.asScala.map(plan).reduceLeft(Join(_, _, Inner, None))
+ ctx.lateralView.asScala.foldLeft(from)(withGenerate)
+ }
+
+ /**
+ * Connect two queries by a Set operator.
+ *
+ * Supported Set operators are:
+ * - UNION [DISTINCT]
+ * - UNION ALL
+ * - EXCEPT [DISTINCT]
+ * - INTERSECT [DISTINCT]
+ */
+ override def visitSetOperation(ctx: SetOperationContext): LogicalPlan = withOrigin(ctx) {
+ val left = plan(ctx.left)
+ val right = plan(ctx.right)
+ val all = Option(ctx.setQuantifier()).exists(_.ALL != null)
+ ctx.operator.getType match {
+ case SqlBaseParser.UNION if all =>
+ Union(left, right)
+ case SqlBaseParser.UNION =>
+ Distinct(Union(left, right))
+ case SqlBaseParser.INTERSECT if all =>
+ throw new ParseException("INTERSECT ALL is not supported.", ctx)
+ case SqlBaseParser.INTERSECT =>
+ Intersect(left, right)
+ case SqlBaseParser.EXCEPT if all =>
+ throw new ParseException("EXCEPT ALL is not supported.", ctx)
+ case SqlBaseParser.EXCEPT =>
+ Except(left, right)
+ }
+ }
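The resulting mappings, shown schematically (left/right are the child plans):

    // UNION ALL          =>  Union(left, right)
    // UNION [DISTINCT]   =>  Distinct(Union(left, right))
    // INTERSECT          =>  Intersect(left, right)    (INTERSECT ALL is rejected)
    // EXCEPT             =>  Except(left, right)       (EXCEPT ALL is rejected)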
+
+ /**
+ * Add a [[WithWindowDefinition]] operator to a logical plan.
+ */
+ private def withWindows(
+ ctx: WindowsContext,
+ query: LogicalPlan): LogicalPlan = withOrigin(ctx) {
+ // Collect all window specifications defined in the WINDOW clause.
+ val baseWindowMap = ctx.namedWindow.asScala.map {
+ wCtx =>
+ (wCtx.identifier.getText, typedVisit[WindowSpec](wCtx.windowSpec))
+ }.toMap
+
+ // Handle cases like
+ // window w1 as (partition by p_mfgr order by p_name
+ // range between 2 preceding and 2 following),
+ // w2 as w1
+ val windowMapView = baseWindowMap.mapValues {
+ case WindowSpecReference(name) =>
+ baseWindowMap.get(name) match {
+ case Some(spec: WindowSpecDefinition) =>
+ spec
+ case Some(ref) =>
+ throw new ParseException(s"Window reference '$name' is not a window specification", ctx)
+ case None =>
+ throw new ParseException(s"Cannot resolve window reference '$name'", ctx)
+ }
+ case spec: WindowSpecDefinition => spec
+ }
+
+ // Note that mapValues creates a view instead of a materialized map. We force materialization by
+ // mapping over identity.
+ WithWindowDefinition(windowMapView.map(identity), query)
+ }
+
+ /**
+ * Add an [[Aggregate]] to a logical plan.
+ */
+ private def withAggregation(
+ ctx: AggregationContext,
+ selectExpressions: Seq[NamedExpression],
+ query: LogicalPlan): LogicalPlan = withOrigin(ctx) {
+ import ctx._
+ val groupByExpressions = expressionList(groupingExpressions)
+
+ if (GROUPING != null) {
+ // GROUP BY .... GROUPING SETS (...)
+ val expressionMap = groupByExpressions.zipWithIndex.toMap
+ val numExpressions = expressionMap.size
+ val mask = (1 << numExpressions) - 1
+ val masks = ctx.groupingSet.asScala.map {
+ _.expression.asScala.foldLeft(mask) {
+ case (bitmap, eCtx) =>
+ // Find the index of the expression.
+ val e = typedVisit[Expression](eCtx)
+ val index = expressionMap.find(_._1.semanticEquals(e)).map(_._2).getOrElse(
+ throw new ParseException(
+ s"$e doesn't show up in the GROUP BY list", ctx))
+ // 0 means that the column at the given index is a grouping column, 1 means it is not,
+ // so we unset the bit in bitmap.
+ bitmap & ~(1 << (numExpressions - 1 - index))
+ }
+ }
+ GroupingSets(masks, groupByExpressions, query, selectExpressions)
+ } else {
+ // GROUP BY .... (WITH CUBE | WITH ROLLUP)?
+ val mappedGroupByExpressions = if (CUBE != null) {
+ Seq(Cube(groupByExpressions))
+ } else if (ROLLUP != null) {
+ Seq(Rollup(groupByExpressions))
+ } else {
+ groupByExpressions
+ }
+ Aggregate(mappedGroupByExpressions, selectExpressions, query)
+ }
+ }
+
+ /**
+ * Add a [[Generate]] (Lateral View) to a logical plan.
+ */
+ private def withGenerate(
+ query: LogicalPlan,
+ ctx: LateralViewContext): LogicalPlan = withOrigin(ctx) {
+ val expressions = expressionList(ctx.expression)
+
+ // Create the generator.
+ val generator = ctx.qualifiedName.getText.toLowerCase match {
+ case "explode" if expressions.size == 1 =>
+ Explode(expressions.head)
+ case "json_tuple" =>
+ JsonTuple(expressions)
+ case name =>
+ UnresolvedGenerator(name, expressions)
+ }
+
+ Generate(
+ generator,
+ join = true,
+ outer = ctx.OUTER != null,
+ Some(ctx.tblName.getText.toLowerCase),
+ ctx.colName.asScala.map(_.getText).map(UnresolvedAttribute.apply),
+ query)
+ }
+
+ /**
+ * Create a join between two or more logical plans.
+ */
+ override def visitJoinRelation(ctx: JoinRelationContext): LogicalPlan = withOrigin(ctx) {
+ /** Build a join between two plans. */
+ def join(ctx: JoinRelationContext, left: LogicalPlan, right: LogicalPlan): Join = {
+ val baseJoinType = ctx.joinType match {
+ case null => Inner
+ case jt if jt.FULL != null => FullOuter
+ case jt if jt.SEMI != null => LeftSemi
+ case jt if jt.ANTI != null => LeftAnti
+ case jt if jt.LEFT != null => LeftOuter
+ case jt if jt.RIGHT != null => RightOuter
+ case _ => Inner
+ }
+
+ // Resolve the join type and join condition
+ val (joinType, condition) = Option(ctx.joinCriteria) match {
+ case Some(c) if c.USING != null =>
+ val columns = c.identifier.asScala.map { column =>
+ UnresolvedAttribute.quoted(column.getText)
+ }
+ (UsingJoin(baseJoinType, columns), None)
+ case Some(c) if c.booleanExpression != null =>
+ (baseJoinType, Option(expression(c.booleanExpression)))
+ case None if ctx.NATURAL != null =>
+ (NaturalJoin(baseJoinType), None)
+ case None =>
+ (baseJoinType, None)
+ }
+ Join(left, right, joinType, condition)
+ }
+
+ // Handle all consecutive join clauses. ANTLR produces a right-nested tree in which the
+ // first join clause is at the top. However, fields of previously referenced tables can be used
+ // in following join clauses. The tree needs to be reversed in order to make this work.
+ var result = plan(ctx.left)
+ var current = ctx
+ while (current != null) {
+ current.right match {
+ case right: JoinRelationContext =>
+ result = join(current, result, plan(right.left))
+ current = right
+ case right =>
+ result = join(current, result, plan(right))
+ current = null
+ }
+ }
+ result
+ }
+
+ /**
+ * Add a [[Sample]] to a logical plan.
+ *
+ * This currently supports the following sampling methods:
+ * - TABLESAMPLE(x ROWS): Sample the table down to the given number of rows.
+ * - TABLESAMPLE(x PERCENT): Sample the table down to the given percentage. Note that percentages
+ * are defined as a number between 0 and 100.
+ * - TABLESAMPLE(BUCKET x OUT OF y): Sample the table down to an 'x' divided by 'y' fraction.
+ */
+ private def withSample(ctx: SampleContext, query: LogicalPlan): LogicalPlan = withOrigin(ctx) {
+ // Create a sampled plan if we need one.
+ def sample(fraction: Double): Sample = {
+ // The range of fraction accepted by Sample is [0, 1]. Because Hive's block sampling
+ // function takes X PERCENT as the input and the range of X is [0, 100], we need to
+ // adjust the fraction.
+ val eps = RandomSampler.roundingEpsilon
+ assert(fraction >= 0.0 - eps && fraction <= 1.0 + eps,
+ s"Sampling fraction ($fraction) must be on interval [0, 1]",
+ ctx)
+ Sample(0.0, fraction, withReplacement = false, (math.random * 1000).toInt, query)(true)
+ }
+
+ ctx.sampleType.getType match {
+ case SqlBaseParser.ROWS =>
+ Limit(expression(ctx.expression), query)
+
+ case SqlBaseParser.PERCENTLIT =>
+ val fraction = ctx.percentage.getText.toDouble
+ sample(fraction / 100.0d)
+
+ case SqlBaseParser.BUCKET if ctx.ON != null =>
+ throw new ParseException("TABLESAMPLE(BUCKET x OUT OF y ON id) is not supported", ctx)
+
+ case SqlBaseParser.BUCKET =>
+ sample(ctx.numerator.getText.toDouble / ctx.denominator.getText.toDouble)
+ }
+ }
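Illustrative mappings for the supported TABLESAMPLE forms (a sketch; seed and bounds checks elided):

    // TABLESAMPLE(5 ROWS)             =>  Limit(5, query)
    // TABLESAMPLE(30 PERCENT)         =>  Sample(0.0, 0.30, withReplacement = false, seed, query)
    // TABLESAMPLE(BUCKET 1 OUT OF 4)  =>  Sample(0.0, 0.25, withReplacement = false, seed, query)
    // TABLESAMPLE(BUCKET x OUT OF y ON id) is rejected with a ParseException.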
+
+ /**
+ * Create a logical plan for a sub-query.
+ */
+ override def visitSubquery(ctx: SubqueryContext): LogicalPlan = withOrigin(ctx) {
+ plan(ctx.queryNoWith)
+ }
+
+ /**
+ * Create an un-aliased table reference. This is typically used for top-level table references,
+ * for example:
+ * {{{
+ * INSERT INTO db.tbl2
+ * TABLE db.tbl1
+ * }}}
+ */
+ override def visitTable(ctx: TableContext): LogicalPlan = withOrigin(ctx) {
+ UnresolvedRelation(visitTableIdentifier(ctx.tableIdentifier), None)
+ }
+
+ /**
+ * Create an aliased table reference. This is typically used in FROM clauses.
+ */
+ override def visitTableName(ctx: TableNameContext): LogicalPlan = withOrigin(ctx) {
+ val table = UnresolvedRelation(
+ visitTableIdentifier(ctx.tableIdentifier),
+ Option(ctx.identifier).map(_.getText))
+ table.optionalMap(ctx.sample)(withSample)
+ }
+
+ /**
+ * Create an inline table (a virtual table in Hive parlance).
+ */
+ override def visitInlineTable(ctx: InlineTableContext): LogicalPlan = withOrigin(ctx) {
+ // Get the backing expressions.
+ val expressions = ctx.expression.asScala.map { eCtx =>
+ val e = expression(eCtx)
+ assert(e.foldable, "All expressions in an inline table must be constants.", eCtx)
+ e
+ }
+
+ // Validate and evaluate the rows.
+ val (structType, structConstructor) = expressions.head.dataType match {
+ case st: StructType =>
+ (st, (e: Expression) => e)
+ case dt =>
+ val st = CreateStruct(Seq(expressions.head)).dataType
+ (st, (e: Expression) => CreateStruct(Seq(e)))
+ }
+ val rows = expressions.map {
+ case expression =>
+ val safe = Cast(structConstructor(expression), structType)
+ safe.eval().asInstanceOf[InternalRow]
+ }
+
+ // Construct attributes.
+ val baseAttributes = structType.toAttributes.map(_.withNullability(true))
+ val attributes = if (ctx.identifierList != null) {
+ val aliases = visitIdentifierList(ctx.identifierList)
+ assert(aliases.size == baseAttributes.size,
+ "Number of aliases must match the number of fields in an inline table.", ctx)
+ baseAttributes.zip(aliases).map(p => p._1.withName(p._2))
+ } else {
+ baseAttributes
+ }
+
+ // Create plan and add an alias if a name has been defined.
+ LocalRelation(attributes, rows).optionalMap(ctx.identifier)(aliasPlan)
+ }
+
+ /**
+ * Create an alias (SubqueryAlias) for a join relation. This is practically the same as
+ * visitAliasedQuery and visitNamedExpression; ANTLR4, however, requires us to use 3 different
+ * hooks.
+ */
+ override def visitAliasedRelation(ctx: AliasedRelationContext): LogicalPlan = withOrigin(ctx) {
+ plan(ctx.relation).optionalMap(ctx.sample)(withSample).optionalMap(ctx.identifier)(aliasPlan)
+ }
+
+ /**
+ * Create an alias (SubqueryAlias) for a sub-query. This is practically the same as
+ * visitAliasedRelation and visitNamedExpression; ANTLR4, however, requires us to use 3 different
+ * hooks.
+ */
+ override def visitAliasedQuery(ctx: AliasedQueryContext): LogicalPlan = withOrigin(ctx) {
+ plan(ctx.queryNoWith).optionalMap(ctx.sample)(withSample).optionalMap(ctx.identifier)(aliasPlan)
+ }
+
+ /**
+ * Create an alias (SubqueryAlias) for a LogicalPlan.
+ */
+ private def aliasPlan(alias: IdentifierContext, plan: LogicalPlan): LogicalPlan = {
+ SubqueryAlias(alias.getText, plan)
+ }
+
+ /**
+ * Create a Sequence of Strings for a parenthesis enclosed alias list.
+ */
+ override def visitIdentifierList(ctx: IdentifierListContext): Seq[String] = withOrigin(ctx) {
+ visitIdentifierSeq(ctx.identifierSeq)
+ }
+
+ /**
+ * Create a Sequence of Strings for an identifier list.
+ */
+ override def visitIdentifierSeq(ctx: IdentifierSeqContext): Seq[String] = withOrigin(ctx) {
+ ctx.identifier.asScala.map(_.getText)
+ }
+
+ /* ********************************************************************************************
+ * Table Identifier parsing
+ * ******************************************************************************************** */
+ /**
+ * Create a [[TableIdentifier]] from a 'tableName' or 'databaseName'.'tableName' pattern.
+ */
+ override def visitTableIdentifier(
+ ctx: TableIdentifierContext): TableIdentifier = withOrigin(ctx) {
+ TableIdentifier(ctx.table.getText, Option(ctx.db).map(_.getText))
+ }
+
+ /* ********************************************************************************************
+ * Expression parsing
+ * ******************************************************************************************** */
+ /**
+ * Create an expression from the given context. This method just passes the context on to the
+ * visitor and only takes care of typing (we assume that the visitor returns an Expression here).
+ */
+ protected def expression(ctx: ParserRuleContext): Expression = typedVisit(ctx)
+
+ /**
+ * Create sequence of expressions from the given sequence of contexts.
+ */
+ private def expressionList(trees: java.util.List[ExpressionContext]): Seq[Expression] = {
+ trees.asScala.map(expression)
+ }
+
+ /**
+ * Create a star (i.e. all) expression; this selects all elements (in the specified object).
+ * Both un-targeted (global) and targeted aliases are supported.
+ */
+ override def visitStar(ctx: StarContext): Expression = withOrigin(ctx) {
+ UnresolvedStar(Option(ctx.qualifiedName()).map(_.identifier.asScala.map(_.getText)))
+ }
+
+ /**
+ * Create an aliased expression if an alias is specified. Both single and multi-aliases are
+ * supported.
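+ * For example (illustrative): 'expr AS name' yields an Alias, while the multi-alias form
+ * 'expr AS (name1, name2)' yields a MultiAlias.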
+ */
+ override def visitNamedExpression(ctx: NamedExpressionContext): Expression = withOrigin(ctx) {
+ val e = expression(ctx.expression)
+ if (ctx.identifier != null) {
+ Alias(e, ctx.identifier.getText)()
+ } else if (ctx.identifierList != null) {
+ MultiAlias(e, visitIdentifierList(ctx.identifierList))
+ } else {
+ e
+ }
+ }
+
+ /**
+ * Combine a number of boolean expressions into a balanced expression tree. These expressions are
+ * either combined by a logical [[And]] or a logical [[Or]].
+ *
+ * A balanced binary tree is created because regular left-recursive trees cause considerable
+ * performance degradation and can cause stack overflows.
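+ *
+ * For example (illustrative), 'a AND b AND c AND d' is combined into
+ * {{{
+ * And(And(a, b), And(c, d))
+ * }}}
+ * rather than the left-deep And(And(And(a, b), c), d).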
+ */
+ override def visitLogicalBinary(ctx: LogicalBinaryContext): Expression = withOrigin(ctx) {
+ val expressionType = ctx.operator.getType
+ val expressionCombiner = expressionType match {
+ case SqlBaseParser.AND => And.apply _
+ case SqlBaseParser.OR => Or.apply _
+ }
+
+ // Collect all similar left hand contexts.
+ val contexts = ArrayBuffer(ctx.right)
+ var current = ctx.left
+ def collectContexts: Boolean = current match {
+ case lbc: LogicalBinaryContext if lbc.operator.getType == expressionType =>
+ contexts += lbc.right
+ current = lbc.left
+ true
+ case _ =>
+ contexts += current
+ false
+ }
+ while (collectContexts) {
+ // No body - all updates take place in the collectContexts.
+ }
+
+ // Reverse the contexts to have them in the same sequence as in the SQL statement & turn them
+ // into expressions.
+ val expressions = contexts.reverse.map(expression)
+
+ // Create a balanced tree.
+ def reduceToExpressionTree(low: Int, high: Int): Expression = high - low match {
+ case 0 =>
+ expressions(low)
+ case 1 =>
+ expressionCombiner(expressions(low), expressions(high))
+ case x =>
+ val mid = low + x / 2
+ expressionCombiner(
+ reduceToExpressionTree(low, mid),
+ reduceToExpressionTree(mid + 1, high))
+ }
+ reduceToExpressionTree(0, expressions.size - 1)
+ }
+
+ /**
+ * Invert a boolean expression.
+ */
+ override def visitLogicalNot(ctx: LogicalNotContext): Expression = withOrigin(ctx) {
+ Not(expression(ctx.booleanExpression()))
+ }
+
+ /**
+ * Create a filtering correlated sub-query. This is not supported yet.
+ */
+ override def visitExists(ctx: ExistsContext): Expression = {
+ throw new ParseException("EXISTS clauses are not supported.", ctx)
+ }
+
+ /**
+ * Create a comparison expression. This compares two expressions. The following comparison
+ * operators are supported:
+ * - Equal: '=' or '=='
+ * - Null-safe Equal: '<=>'
+ * - Not Equal: '<>' or '!='
+ * - Less than: '<'
+ * - Less than or Equal: '<='
+ * - Greater than: '>'
+ * - Greater than or Equal: '>='
+ */
+ override def visitComparison(ctx: ComparisonContext): Expression = withOrigin(ctx) {
+ val left = expression(ctx.left)
+ val right = expression(ctx.right)
+ val operator = ctx.comparisonOperator().getChild(0).asInstanceOf[TerminalNode]
+ operator.getSymbol.getType match {
+ case SqlBaseParser.EQ =>
+ EqualTo(left, right)
+ case SqlBaseParser.NSEQ =>
+ EqualNullSafe(left, right)
+ case SqlBaseParser.NEQ | SqlBaseParser.NEQJ =>
+ Not(EqualTo(left, right))
+ case SqlBaseParser.LT =>
+ LessThan(left, right)
+ case SqlBaseParser.LTE =>
+ LessThanOrEqual(left, right)
+ case SqlBaseParser.GT =>
+ GreaterThan(left, right)
+ case SqlBaseParser.GTE =>
+ GreaterThanOrEqual(left, right)
+ }
+ }
+
+ /**
+ * Create a predicated expression. A predicated expression is a normal expression with a
+ * predicate attached to it, for example:
+ * {{{
+ * a + 1 IS NULL
+ * }}}
+ */
+ override def visitPredicated(ctx: PredicatedContext): Expression = withOrigin(ctx) {
+ val e = expression(ctx.valueExpression)
+ if (ctx.predicate != null) {
+ withPredicate(e, ctx.predicate)
+ } else {
+ e
+ }
+ }
+
+ /**
+ * Add a predicate to the given expression. Supported predicates are:
+ * - (NOT) BETWEEN
+ * - (NOT) IN
+ * - (NOT) LIKE
+ * - (NOT) RLIKE
+ * - IS (NOT) NULL.
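+ *
+ * For example (illustrative), 'a BETWEEN 1 AND 5' is translated into
+ * {{{
+ * And(GreaterThanOrEqual(a, 1), LessThanOrEqual(a, 5))
+ * }}}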
+ */
+ private def withPredicate(e: Expression, ctx: PredicateContext): Expression = withOrigin(ctx) {
+ // Invert a predicate if it has a valid NOT clause.
+ def invertIfNotDefined(e: Expression): Expression = ctx.NOT match {
+ case null => e
+ case not => Not(e)
+ }
+
+ // Create the predicate.
+ ctx.kind.getType match {
+ case SqlBaseParser.BETWEEN =>
+ // BETWEEN is translated to lower <= e && e <= upper
+ invertIfNotDefined(And(
+ GreaterThanOrEqual(e, expression(ctx.lower)),
+ LessThanOrEqual(e, expression(ctx.upper))))
+ case SqlBaseParser.IN if ctx.query != null =>
+ throw new ParseException("IN with a Sub-query is currently not supported.", ctx)
+ case SqlBaseParser.IN =>
+ invertIfNotDefined(In(e, ctx.expression.asScala.map(expression)))
+ case SqlBaseParser.LIKE =>
+ invertIfNotDefined(Like(e, expression(ctx.pattern)))
+ case SqlBaseParser.RLIKE =>
+ invertIfNotDefined(RLike(e, expression(ctx.pattern)))
+ case SqlBaseParser.NULL if ctx.NOT != null =>
+ IsNotNull(e)
+ case SqlBaseParser.NULL =>
+ IsNull(e)
+ }
+ }
+
+ /**
+ * Create a binary arithmetic expression. The following arithmetic operators are supported:
+ * - Multiplication: '*'
+ * - Division: '/'
+ * - Hive Long Division: 'DIV'
+ * - Modulo: '%'
+ * - Addition: '+'
+ * - Subtraction: '-'
+ * - Binary AND: '&'
+ * - Binary XOR: '^'
+ * - Binary OR: '|'
+ */
+ override def visitArithmeticBinary(ctx: ArithmeticBinaryContext): Expression = withOrigin(ctx) {
+ val left = expression(ctx.left)
+ val right = expression(ctx.right)
+ ctx.operator.getType match {
+ case SqlBaseParser.ASTERISK =>
+ Multiply(left, right)
+ case SqlBaseParser.SLASH =>
+ Divide(left, right)
+ case SqlBaseParser.PERCENT =>
+ Remainder(left, right)
+ case SqlBaseParser.DIV =>
+ Cast(Divide(left, right), LongType)
+ case SqlBaseParser.PLUS =>
+ Add(left, right)
+ case SqlBaseParser.MINUS =>
+ Subtract(left, right)
+ case SqlBaseParser.AMPERSAND =>
+ BitwiseAnd(left, right)
+ case SqlBaseParser.HAT =>
+ BitwiseXor(left, right)
+ case SqlBaseParser.PIPE =>
+ BitwiseOr(left, right)
+ }
+ }
+
+ /**
+ * Create a unary arithmetic expression. The following arithmetic operators are supported:
+ * - Plus: '+'
+ * - Minus: '-'
+ * - Bitwise Not: '~'
+ */
+ override def visitArithmeticUnary(ctx: ArithmeticUnaryContext): Expression = withOrigin(ctx) {
+ val value = expression(ctx.valueExpression)
+ ctx.operator.getType match {
+ case SqlBaseParser.PLUS =>
+ value
+ case SqlBaseParser.MINUS =>
+ UnaryMinus(value)
+ case SqlBaseParser.TILDE =>
+ BitwiseNot(value)
+ }
+ }
+
+ /**
+ * Create a [[Cast]] expression.
+ */
+ override def visitCast(ctx: CastContext): Expression = withOrigin(ctx) {
+ Cast(expression(ctx.expression), typedVisit(ctx.dataType))
+ }
+
+ /**
+ * Create a (windowed) Function expression.
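+ * For example (illustrative names):
+ * {{{
+ * count(DISTINCT a)
+ * rank() OVER (PARTITION BY key ORDER BY value)
+ * }}}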
+ */
+ override def visitFunctionCall(ctx: FunctionCallContext): Expression = withOrigin(ctx) {
+ // Create the function call.
+ val name = ctx.qualifiedName.getText
+ val isDistinct = Option(ctx.setQuantifier()).exists(_.DISTINCT != null)
+ val arguments = ctx.expression().asScala.map(expression) match {
+ case Seq(UnresolvedStar(None)) if name.toLowerCase == "count" && !isDistinct =>
+ // Transform COUNT(*) into COUNT(1). Move this to analysis?
+ Seq(Literal(1))
+ case expressions =>
+ expressions
+ }
+ val function = UnresolvedFunction(name, arguments, isDistinct)
+
+ // Check if the function is evaluated in a windowed context.
+ ctx.windowSpec match {
+ case spec: WindowRefContext =>
+ UnresolvedWindowExpression(function, visitWindowRef(spec))
+ case spec: WindowDefContext =>
+ WindowExpression(function, visitWindowDef(spec))
+ case _ => function
+ }
+ }
+
+ /**
+ * Create a reference to a window frame, i.e. [[WindowSpecReference]].
+ */
+ override def visitWindowRef(ctx: WindowRefContext): WindowSpecReference = withOrigin(ctx) {
+ WindowSpecReference(ctx.identifier.getText)
+ }
+
+ /**
+ * Create a window definition, i.e. [[WindowSpecDefinition]].
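+ * For example (illustrative names):
+ * {{{
+ * OVER (PARTITION BY a ORDER BY b ROWS BETWEEN 1 PRECEDING AND CURRENT ROW)
+ * }}}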
+ */
+ override def visitWindowDef(ctx: WindowDefContext): WindowSpecDefinition = withOrigin(ctx) {
+ // CLUSTER BY ... | PARTITION BY ... ORDER BY ...
+ val partition = ctx.partition.asScala.map(expression)
+ val order = ctx.sortItem.asScala.map(visitSortItem)
+
+ // RANGE/ROWS BETWEEN ...
+ val frameSpecOption = Option(ctx.windowFrame).map { frame =>
+ val frameType = frame.frameType.getType match {
+ case SqlBaseParser.RANGE => RangeFrame
+ case SqlBaseParser.ROWS => RowFrame
+ }
+
+ SpecifiedWindowFrame(
+ frameType,
+ visitFrameBound(frame.start),
+ Option(frame.end).map(visitFrameBound).getOrElse(CurrentRow))
+ }
+
+ WindowSpecDefinition(
+ partition,
+ order,
+ frameSpecOption.getOrElse(UnspecifiedFrame))
+ }
+
+ /**
+ * Create or resolve a [[FrameBoundary]]. Simple math expressions are allowed for Value
+ * Preceding/Following boundaries. These expressions must be constant (foldable) and return an
+ * integer value.
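+ *
+ * For example (illustrative): '3 PRECEDING' becomes ValuePreceding(3), 'UNBOUNDED FOLLOWING'
+ * becomes UnboundedFollowing, and 'CURRENT ROW' becomes CurrentRow.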
+ */
+ override def visitFrameBound(ctx: FrameBoundContext): FrameBoundary = withOrigin(ctx) {
+ // We currently only allow foldable integers.
+ def value: Int = {
+ val e = expression(ctx.expression)
+ assert(e.resolved && e.foldable && e.dataType == IntegerType,
+ "Frame bound value must be a constant integer.",
+ ctx)
+ e.eval().asInstanceOf[Int]
+ }
+
+ // Create the FrameBoundary
+ ctx.boundType.getType match {
+ case SqlBaseParser.PRECEDING if ctx.UNBOUNDED != null =>
+ UnboundedPreceding
+ case SqlBaseParser.PRECEDING =>
+ ValuePreceding(value)
+ case SqlBaseParser.CURRENT =>
+ CurrentRow
+ case SqlBaseParser.FOLLOWING if ctx.UNBOUNDED != null =>
+ UnboundedFollowing
+ case SqlBaseParser.FOLLOWING =>
+ ValueFollowing(value)
+ }
+ }
+
+ /**
+ * Create a [[CreateStruct]] expression.
+ */
+ override def visitRowConstructor(ctx: RowConstructorContext): Expression = withOrigin(ctx) {
+ CreateStruct(ctx.expression.asScala.map(expression))
+ }
+
+ /**
+ * Create a [[ScalarSubquery]] expression.
+ */
+ override def visitSubqueryExpression(
+ ctx: SubqueryExpressionContext): Expression = withOrigin(ctx) {
+ ScalarSubquery(plan(ctx.query))
+ }
+
+ /**
+ * Create a value based [[CaseWhen]] expression. This has the following SQL form:
+ * {{{
+ * CASE [expression]
+ * WHEN [value] THEN [expression]
+ * ...
+ * ELSE [expression]
+ * END
+ * }}}
+ */
+ override def visitSimpleCase(ctx: SimpleCaseContext): Expression = withOrigin(ctx) {
+ val e = expression(ctx.valueExpression)
+ val branches = ctx.whenClause.asScala.map { wCtx =>
+ (EqualTo(e, expression(wCtx.condition)), expression(wCtx.result))
+ }
+ CaseWhen(branches, Option(ctx.elseExpression).map(expression))
+ }
+
+ /**
+ * Create a condition based [[CaseWhen]] expression. This has the following SQL syntax:
+ * {{{
+ * CASE
+ * WHEN [predicate] THEN [expression]
+ * ...
+ * ELSE [expression]
+ * END
+ * }}}
+ */
+ override def visitSearchedCase(ctx: SearchedCaseContext): Expression = withOrigin(ctx) {
+ val branches = ctx.whenClause.asScala.map { wCtx =>
+ (expression(wCtx.condition), expression(wCtx.result))
+ }
+ CaseWhen(branches, Option(ctx.elseExpression).map(expression))
+ }
+
+ /**
+ * Create a dereference expression. The return type depends on the type of the parent: this can
+ * either be an [[UnresolvedAttribute]] (if the parent is an [[UnresolvedAttribute]]), or an
+ * [[UnresolvedExtractValue]] if the parent is some other expression.
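+ *
+ * For example (illustrative names):
+ * {{{
+ * tbl.col => UnresolvedAttribute(Seq("tbl", "col"))
+ * arr[0].field => UnresolvedExtractValue(<subscript expression>, Literal("field"))
+ * }}}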
+ */
+ override def visitDereference(ctx: DereferenceContext): Expression = withOrigin(ctx) {
+ val attr = ctx.fieldName.getText
+ expression(ctx.base) match {
+ case UnresolvedAttribute(nameParts) =>
+ UnresolvedAttribute(nameParts :+ attr)
+ case e =>
+ UnresolvedExtractValue(e, Literal(attr))
+ }
+ }
+
+ /**
+ * Create an [[UnresolvedAttribute]] expression.
+ */
+ override def visitColumnReference(ctx: ColumnReferenceContext): Expression = withOrigin(ctx) {
+ UnresolvedAttribute.quoted(ctx.getText)
+ }
+
+ /**
+ * Create an [[UnresolvedExtractValue]] expression; this is used for subscript access to an array.
+ */
+ override def visitSubscript(ctx: SubscriptContext): Expression = withOrigin(ctx) {
+ UnresolvedExtractValue(expression(ctx.value), expression(ctx.index))
+ }
+
+ /**
+ * Create an expression for an expression between parentheses. This is needed because the ANTLR
+ * visitor cannot automatically convert the nested context into an expression.
+ */
+ override def visitParenthesizedExpression(
+ ctx: ParenthesizedExpressionContext): Expression = withOrigin(ctx) {
+ expression(ctx.expression)
+ }
+
+ /**
+ * Create a [[SortOrder]] expression.
+ */
+ override def visitSortItem(ctx: SortItemContext): SortOrder = withOrigin(ctx) {
+ if (ctx.DESC != null) {
+ SortOrder(expression(ctx.expression), Descending)
+ } else {
+ SortOrder(expression(ctx.expression), Ascending)
+ }
+ }
+
+ /**
+ * Create a typed Literal expression. A typed literal has the following SQL syntax:
+ * {{{
+ * [TYPE] '[VALUE]'
+ * }}}
+ * Currently Date and Timestamp typed literals are supported.
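+ * For example (illustrative values):
+ * {{{
+ * DATE '2016-03-11'
+ * TIMESTAMP '2016-03-11 20:54:00'
+ * }}}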
+ *
+ * TODO what is the added value of this over casting?
+ */
+ override def visitTypeConstructor(ctx: TypeConstructorContext): Literal = withOrigin(ctx) {
+ val value = string(ctx.STRING)
+ ctx.identifier.getText.toUpperCase match {
+ case "DATE" =>
+ Literal(Date.valueOf(value))
+ case "TIMESTAMP" =>
+ Literal(Timestamp.valueOf(value))
+ case other =>
+ throw new ParseException(s"Literals of type '$other' are currently not supported.", ctx)
+ }
+ }
+
+ /**
+ * Create a NULL literal expression.
+ */
+ override def visitNullLiteral(ctx: NullLiteralContext): Literal = withOrigin(ctx) {
+ Literal(null)
+ }
+
+ /**
+ * Create a Boolean literal expression.
+ */
+ override def visitBooleanLiteral(ctx: BooleanLiteralContext): Literal = withOrigin(ctx) {
+ if (ctx.getText.toBoolean) {
+ Literal.TrueLiteral
+ } else {
+ Literal.FalseLiteral
+ }
+ }
+
+ /**
+ * Create an integral literal expression. The code selects the narrowest integral type
+ * possible; either a BigDecimal, a Long or an Integer is returned.
+ */
+ override def visitIntegerLiteral(ctx: IntegerLiteralContext): Literal = withOrigin(ctx) {
+ BigDecimal(ctx.getText) match {
+ case v if v.isValidInt =>
+ Literal(v.intValue())
+ case v if v.isValidLong =>
+ Literal(v.longValue())
+ case v => Literal(v.underlying())
+ }
+ }
+
+ /**
+ * Create a double literal for a number denoted in scientific notation.
+ */
+ override def visitScientificDecimalLiteral(
+ ctx: ScientificDecimalLiteralContext): Literal = withOrigin(ctx) {
+ Literal(ctx.getText.toDouble)
+ }
+
+ /**
+ * Create a decimal literal for a regular decimal number.
+ */
+ override def visitDecimalLiteral(ctx: DecimalLiteralContext): Literal = withOrigin(ctx) {
+ Literal(BigDecimal(ctx.getText).underlying())
+ }
+
+ /** Create a numeric literal expression. */
+ private def numericLiteral(ctx: NumberContext)(f: String => Any): Literal = withOrigin(ctx) {
+ val raw = ctx.getText
+ try {
+ Literal(f(raw.substring(0, raw.length - 1)))
+ } catch {
+ case e: NumberFormatException =>
+ throw new ParseException(e.getMessage, ctx)
+ }
+ }
+
+ /**
+ * Create a Byte Literal expression.
+ */
+ override def visitTinyIntLiteral(ctx: TinyIntLiteralContext): Literal = numericLiteral(ctx) {
+ _.toByte
+ }
+
+ /**
+ * Create a Short Literal expression.
+ */
+ override def visitSmallIntLiteral(ctx: SmallIntLiteralContext): Literal = numericLiteral(ctx) {
+ _.toShort
+ }
+
+ /**
+ * Create a Long Literal expression.
+ */
+ override def visitBigIntLiteral(ctx: BigIntLiteralContext): Literal = numericLiteral(ctx) {
+ _.toLong
+ }
+
+ /**
+ * Create a Double Literal expression.
+ */
+ override def visitDoubleLiteral(ctx: DoubleLiteralContext): Literal = numericLiteral(ctx) {
+ _.toDouble
+ }
+
+ /**
+ * Create a String literal expression.
+ */
+ override def visitStringLiteral(ctx: StringLiteralContext): Literal = withOrigin(ctx) {
+ Literal(createString(ctx))
+ }
+
+ /**
+ * Create a String from a string literal context. This supports multiple consecutive string
+ * literals, which are concatenated; for example, the expression "'hello' 'world'" will be
+ * converted into "helloworld".
+ *
+ * Special characters can be escaped by using Hive/C-style escaping.
+ */
+ private def createString(ctx: StringLiteralContext): String = {
+ ctx.STRING().asScala.map(string).mkString
+ }
+
+ /**
+ * Create a [[CalendarInterval]] literal expression. An interval expression can contain multiple
+ * unit-value pairs, for instance: interval 2 months 2 days.
+ */
+ override def visitInterval(ctx: IntervalContext): Literal = withOrigin(ctx) {
+ val intervals = ctx.intervalField.asScala.map(visitIntervalField)
+ assert(intervals.nonEmpty, "at least one time unit should be given for interval literal", ctx)
+ Literal(intervals.reduce(_.add(_)))
+ }
+
+ /**
+ * Create a [[CalendarInterval]] for a unit value pair. Two unit configuration types are
+ * supported:
+ * - Single unit.
+ * - From-To unit (only 'YEAR TO MONTH' and 'DAY TO SECOND' are supported).
+ */
+ override def visitIntervalField(ctx: IntervalFieldContext): CalendarInterval = withOrigin(ctx) {
+ import ctx._
+ val s = value.getText
+ try {
+ val interval = (unit.getText.toLowerCase, Option(to).map(_.getText.toLowerCase)) match {
+ case (u, None) if u.endsWith("s") =>
+ // Handle plural forms, e.g.: yearS/monthS/weekS/dayS/hourS/minuteS/secondS/...
+ CalendarInterval.fromSingleUnitString(u.substring(0, u.length - 1), s)
+ case (u, None) =>
+ CalendarInterval.fromSingleUnitString(u, s)
+ case ("year", Some("month")) =>
+ CalendarInterval.fromYearMonthString(s)
+ case ("day", Some("second")) =>
+ CalendarInterval.fromDayTimeString(s)
+ case (from, Some(t)) =>
+ throw new ParseException(s"Intervals FROM $from TO $t are not supported.", ctx)
+ }
+ assert(interval != null, "No interval can be constructed", ctx)
+ interval
+ } catch {
+ // Handle Exceptions thrown by CalendarInterval
+ case e: IllegalArgumentException =>
+ val pe = new ParseException(e.getMessage, ctx)
+ pe.setStackTrace(e.getStackTrace)
+ throw pe
+ }
+ }
+
+ /* ********************************************************************************************
+ * DataType parsing
+ * ******************************************************************************************** */
+ /**
+ * Resolve/create a primitive type.
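+ * For example (illustrative):
+ * {{{
+ * int => IntegerType
+ * decimal(10, 2) => DecimalType(10, 2)
+ * }}}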
+ */
+ override def visitPrimitiveDataType(ctx: PrimitiveDataTypeContext): DataType = withOrigin(ctx) {
+ (ctx.identifier.getText.toLowerCase, ctx.INTEGER_VALUE().asScala.toList) match {
+ case ("boolean", Nil) => BooleanType
+ case ("tinyint" | "byte", Nil) => ByteType
+ case ("smallint" | "short", Nil) => ShortType
+ case ("int" | "integer", Nil) => IntegerType
+ case ("bigint" | "long", Nil) => LongType
+ case ("float", Nil) => FloatType
+ case ("double", Nil) => DoubleType
+ case ("date", Nil) => DateType
+ case ("timestamp", Nil) => TimestampType
+ case ("char" | "varchar" | "string", Nil) => StringType
+ case ("char" | "varchar", _ :: Nil) => StringType
+ case ("binary", Nil) => BinaryType
+ case ("decimal", Nil) => DecimalType.USER_DEFAULT
+ case ("decimal", precision :: Nil) => DecimalType(precision.getText.toInt, 0)
+ case ("decimal", precision :: scale :: Nil) =>
+ DecimalType(precision.getText.toInt, scale.getText.toInt)
+ case (dt, params) =>
+ throw new ParseException(
+ s"DataType $dt${params.mkString("(", ",", ")")} is not supported.", ctx)
+ }
+ }
+
+ /**
+ * Create a complex DataType. Arrays, Maps and Structures are supported.
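+ * For example (illustrative):
+ * {{{
+ * array<int>
+ * map<string, int>
+ * struct<name: string, age: int>
+ * }}}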
+ */
+ override def visitComplexDataType(ctx: ComplexDataTypeContext): DataType = withOrigin(ctx) {
+ ctx.complex.getType match {
+ case SqlBaseParser.ARRAY =>
+ ArrayType(typedVisit(ctx.dataType(0)))
+ case SqlBaseParser.MAP =>
+ MapType(typedVisit(ctx.dataType(0)), typedVisit(ctx.dataType(1)))
+ case SqlBaseParser.STRUCT =>
+ createStructType(ctx.colTypeList())
+ }
+ }
+
+ /**
+ * Create a [[StructType]] from a sequence of [[StructField]]s.
+ */
+ protected def createStructType(ctx: ColTypeListContext): StructType = {
+ StructType(Option(ctx).toSeq.flatMap(visitColTypeList))
+ }
+
+ /**
+ * Create a [[StructType]] from a number of column definitions.
+ */
+ override def visitColTypeList(ctx: ColTypeListContext): Seq[StructField] = withOrigin(ctx) {
+ ctx.colType().asScala.map(visitColType)
+ }
+
+ /**
+ * Create a [[StructField]] from a column definition.
+ */
+ override def visitColType(ctx: ColTypeContext): StructField = withOrigin(ctx) {
+ import ctx._
+
+ // Add the comment to the metadata.
+ val builder = new MetadataBuilder
+ if (STRING != null) {
+ builder.putString("comment", string(STRING))
+ }
+
+ StructField(identifier.getText, typedVisit(dataType), nullable = true, builder.build())
+ }
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/CatalystQl.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/CatalystQl.scala
deleted file mode 100644
index c188c5b108..0000000000
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/CatalystQl.scala
+++ /dev/null
@@ -1,933 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.catalyst.parser
-
-import java.sql.Date
-
-import scala.collection.mutable.ArrayBuffer
-import scala.util.matching.Regex
-
-import org.apache.spark.sql.AnalysisException
-import org.apache.spark.sql.catalyst.TableIdentifier
-import org.apache.spark.sql.catalyst.analysis._
-import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.aggregate.Count
-import org.apache.spark.sql.catalyst.plans._
-import org.apache.spark.sql.catalyst.plans.logical._
-import org.apache.spark.sql.types._
-import org.apache.spark.unsafe.types.CalendarInterval
-import org.apache.spark.util.random.RandomSampler
-
-
-/**
- * This class translates SQL to Catalyst [[LogicalPlan]]s or [[Expression]]s.
- */
-private[sql] class CatalystQl(val conf: ParserConf = SimpleParserConf()) extends ParserInterface {
- import ParserUtils._
-
- /**
- * The safeParse method allows a user to focus on the parsing/AST transformation logic. This
- * method will take care of possible errors during the parsing process.
- */
- protected def safeParse[T](sql: String, ast: ASTNode)(toResult: ASTNode => T): T = {
- try {
- toResult(ast)
- } catch {
- case e: MatchError => throw e
- case e: AnalysisException => throw e
- case e: Exception =>
- throw new AnalysisException(e.getMessage)
- case e: NotImplementedError =>
- throw new AnalysisException(
- s"""Unsupported language features in query
- |== SQL ==
- |$sql
- |== AST ==
- |${ast.treeString}
- |== Error ==
- |$e
- |== Stacktrace ==
- |${e.getStackTrace.head}
- """.stripMargin)
- }
- }
-
- /** Creates LogicalPlan for a given SQL string. */
- def parsePlan(sql: String): LogicalPlan =
- safeParse(sql, ParseDriver.parsePlan(sql, conf))(nodeToPlan)
-
- /** Creates Expression for a given SQL string. */
- def parseExpression(sql: String): Expression =
- safeParse(sql, ParseDriver.parseExpression(sql, conf))(selExprNodeToExpr(_).get)
-
- /** Creates TableIdentifier for a given SQL string. */
- def parseTableIdentifier(sql: String): TableIdentifier =
- safeParse(sql, ParseDriver.parseTableName(sql, conf))(extractTableIdent)
-
- /**
- * SELECT MAX(value) FROM src GROUP BY k1, k2, k3 GROUPING SETS((k1, k2), (k2))
- * is equivalent to
- * SELECT MAX(value) FROM src GROUP BY k1, k2 UNION SELECT MAX(value) FROM src GROUP BY k2
- * Check the following link for details.
- *
-https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C+Grouping+and+Rollup
- *
- * The bitmask denotes the grouping expressions validity for a grouping set,
- * the bitmask also be called as grouping id (`GROUPING__ID`, the virtual column in Hive)
- * e.g. In superset (k1, k2, k3), (bit 2: k1, bit 1: k2, and bit 0: k3), the grouping id of
- * GROUPING SETS (k1, k2) and (k2) should be 1 and 5 respectively.
- */
- protected def extractGroupingSet(children: Seq[ASTNode]): (Seq[Expression], Seq[Int]) = {
- val (keyASTs, setASTs) = children.partition {
- case Token("TOK_GROUPING_SETS_EXPRESSION", _) => false // grouping sets
- case _ => true // grouping keys
- }
-
- val keys = keyASTs.map(nodeToExpr)
- val keyMap = keyASTs.zipWithIndex.toMap
-
- val mask = (1 << keys.length) - 1
- val bitmasks: Seq[Int] = setASTs.map {
- case Token("TOK_GROUPING_SETS_EXPRESSION", columns) =>
- columns.foldLeft(mask)((bitmap, col) => {
- val keyIndex = keyMap.find(_._1.treeEquals(col)).map(_._2).getOrElse(
- throw new AnalysisException(s"${col.treeString} doesn't show up in the GROUP BY list"))
- // 0 means that the column at the given index is a grouping column, 1 means it is not,
- // so we unset the bit in bitmap.
- bitmap & ~(1 << (keys.length - 1 - keyIndex))
- })
- case _ => sys.error("Expect GROUPING SETS clause")
- }
-
- (keys, bitmasks)
- }
-
- protected def nodeToPlan(node: ASTNode): LogicalPlan = node match {
- case Token("TOK_SHOWFUNCTIONS", args) =>
- // Skip LIKE.
- val pattern = args match {
- case like :: nodes if like.text.toUpperCase == "LIKE" => nodes
- case nodes => nodes
- }
-
- // Extract Database and Function name
- pattern match {
- case Nil =>
- ShowFunctions(None, None)
- case Token(name, Nil) :: Nil =>
- ShowFunctions(None, Some(unquoteString(cleanIdentifier(name))))
- case Token(db, Nil) :: Token(name, Nil) :: Nil =>
- ShowFunctions(Some(unquoteString(cleanIdentifier(db))),
- Some(unquoteString(cleanIdentifier(name))))
- case _ =>
- noParseRule("SHOW FUNCTIONS", node)
- }
-
- case Token("TOK_DESCFUNCTION", Token(functionName, Nil) :: isExtended) =>
- DescribeFunction(cleanIdentifier(functionName), isExtended.nonEmpty)
-
- case Token("TOK_QUERY", queryArgs @ Token("TOK_CTE" | "TOK_FROM" | "TOK_INSERT", _) :: _) =>
- val (fromClause: Option[ASTNode], insertClauses, cteRelations) =
- queryArgs match {
- case Token("TOK_CTE", ctes) :: Token("TOK_FROM", from) :: inserts =>
- val cteRelations = ctes.map { node =>
- val relation = nodeToRelation(node).asInstanceOf[SubqueryAlias]
- relation.alias -> relation
- }
- (Some(from.head), inserts, Some(cteRelations.toMap))
- case Token("TOK_FROM", from) :: inserts =>
- (Some(from.head), inserts, None)
- case Token("TOK_INSERT", _) :: Nil =>
- (None, queryArgs, None)
- }
-
- // Return one query for each insert clause.
- val queries = insertClauses.map {
- case Token("TOK_INSERT", singleInsert) =>
- val (
- intoClause ::
- destClause ::
- selectClause ::
- selectDistinctClause ::
- whereClause ::
- groupByClause ::
- rollupGroupByClause ::
- cubeGroupByClause ::
- groupingSetsClause ::
- orderByClause ::
- havingClause ::
- sortByClause ::
- clusterByClause ::
- distributeByClause ::
- limitClause ::
- lateralViewClause ::
- windowClause :: Nil) = {
- getClauses(
- Seq(
- "TOK_INSERT_INTO",
- "TOK_DESTINATION",
- "TOK_SELECT",
- "TOK_SELECTDI",
- "TOK_WHERE",
- "TOK_GROUPBY",
- "TOK_ROLLUP_GROUPBY",
- "TOK_CUBE_GROUPBY",
- "TOK_GROUPING_SETS",
- "TOK_ORDERBY",
- "TOK_HAVING",
- "TOK_SORTBY",
- "TOK_CLUSTERBY",
- "TOK_DISTRIBUTEBY",
- "TOK_LIMIT",
- "TOK_LATERAL_VIEW",
- "WINDOW"),
- singleInsert)
- }
-
- val relations = fromClause match {
- case Some(f) => nodeToRelation(f)
- case None => OneRowRelation
- }
-
- val withLateralView = lateralViewClause.map { lv =>
- nodeToGenerate(lv.children.head, outer = false, relations)
- }.getOrElse(relations)
-
- val withWhere = whereClause.map { whereNode =>
- val Seq(whereExpr) = whereNode.children
- Filter(nodeToExpr(whereExpr), withLateralView)
- }.getOrElse(withLateralView)
-
- val select = (selectClause orElse selectDistinctClause)
- .getOrElse(sys.error("No select clause."))
-
- val transformation = nodeToTransformation(select.children.head, withWhere)
-
- // The projection of the query can either be a normal projection, an aggregation
- // (if there is a group by) or a script transformation.
- val withProject: LogicalPlan = transformation.getOrElse {
- val selectExpressions =
- select.children.flatMap(selExprNodeToExpr).map(UnresolvedAlias(_))
- Seq(
- groupByClause.map(e => e match {
- case Token("TOK_GROUPBY", children) =>
- // Not a transformation so must be either project or aggregation.
- Aggregate(children.map(nodeToExpr), selectExpressions, withWhere)
- case _ => sys.error("Expect GROUP BY")
- }),
- groupingSetsClause.map(e => e match {
- case Token("TOK_GROUPING_SETS", children) =>
- val(groupByExprs, masks) = extractGroupingSet(children)
- GroupingSets(masks, groupByExprs, withWhere, selectExpressions)
- case _ => sys.error("Expect GROUPING SETS")
- }),
- rollupGroupByClause.map(e => e match {
- case Token("TOK_ROLLUP_GROUPBY", children) =>
- Aggregate(
- Seq(Rollup(children.map(nodeToExpr))),
- selectExpressions,
- withWhere)
- case _ => sys.error("Expect WITH ROLLUP")
- }),
- cubeGroupByClause.map(e => e match {
- case Token("TOK_CUBE_GROUPBY", children) =>
- Aggregate(
- Seq(Cube(children.map(nodeToExpr))),
- selectExpressions,
- withWhere)
- case _ => sys.error("Expect WITH CUBE")
- }),
- Some(Project(selectExpressions, withWhere))).flatten.head
- }
-
- // Handle HAVING clause.
- val withHaving = havingClause.map { h =>
- val havingExpr = h.children match { case Seq(hexpr) => nodeToExpr(hexpr) }
- // Note that we added a cast to boolean. If the expression itself is already boolean,
- // the optimizer will get rid of the unnecessary cast.
- Filter(Cast(havingExpr, BooleanType), withProject)
- }.getOrElse(withProject)
-
- // Handle SELECT DISTINCT
- val withDistinct =
- if (selectDistinctClause.isDefined) Distinct(withHaving) else withHaving
-
- // Handle ORDER BY, SORT BY, DISTRIBUTE BY, and CLUSTER BY clause.
- val withSort =
- (orderByClause, sortByClause, distributeByClause, clusterByClause) match {
- case (Some(totalOrdering), None, None, None) =>
- Sort(totalOrdering.children.map(nodeToSortOrder), global = true, withDistinct)
- case (None, Some(perPartitionOrdering), None, None) =>
- Sort(
- perPartitionOrdering.children.map(nodeToSortOrder),
- global = false, withDistinct)
- case (None, None, Some(partitionExprs), None) =>
- RepartitionByExpression(
- partitionExprs.children.map(nodeToExpr), withDistinct)
- case (None, Some(perPartitionOrdering), Some(partitionExprs), None) =>
- Sort(
- perPartitionOrdering.children.map(nodeToSortOrder), global = false,
- RepartitionByExpression(
- partitionExprs.children.map(nodeToExpr),
- withDistinct))
- case (None, None, None, Some(clusterExprs)) =>
- Sort(
- clusterExprs.children.map(nodeToExpr).map(SortOrder(_, Ascending)),
- global = false,
- RepartitionByExpression(
- clusterExprs.children.map(nodeToExpr),
- withDistinct))
- case (None, None, None, None) => withDistinct
- case _ => sys.error("Unsupported set of ordering / distribution clauses.")
- }
-
- val withLimit =
- limitClause.map(l => nodeToExpr(l.children.head))
- .map(Limit(_, withSort))
- .getOrElse(withSort)
-
- // Collect all window specifications defined in the WINDOW clause.
- val windowDefinitions = windowClause.map(_.children.collect {
- case Token("TOK_WINDOWDEF",
- Token(windowName, Nil) :: Token("TOK_WINDOWSPEC", spec) :: Nil) =>
- windowName -> nodesToWindowSpecification(spec)
- }.toMap)
- // Handle cases like
- // window w1 as (partition by p_mfgr order by p_name
- // range between 2 preceding and 2 following),
- // w2 as w1
- val resolvedCrossReference = windowDefinitions.map {
- windowDefMap => windowDefMap.map {
- case (windowName, WindowSpecReference(other)) =>
- (windowName, windowDefMap(other).asInstanceOf[WindowSpecDefinition])
- case o => o.asInstanceOf[(String, WindowSpecDefinition)]
- }
- }
-
- val withWindowDefinitions =
- resolvedCrossReference.map(WithWindowDefinition(_, withLimit)).getOrElse(withLimit)
-
- // TOK_INSERT_INTO means to add files to the table.
- // TOK_DESTINATION means to overwrite the table.
- val resultDestination =
- (intoClause orElse destClause).getOrElse(sys.error("No destination found."))
- val overwrite = intoClause.isEmpty
- nodeToDest(
- resultDestination,
- withWindowDefinitions,
- overwrite)
- }
-
- // If there are multiple INSERTS just UNION them together into one query.
- val query = if (queries.length == 1) queries.head else Union(queries)
-
- // return With plan if there is CTE
- cteRelations.map(With(query, _)).getOrElse(query)
-
- case Token("TOK_UNIONALL", left :: right :: Nil) =>
- Union(nodeToPlan(left), nodeToPlan(right))
- case Token("TOK_UNIONDISTINCT", left :: right :: Nil) =>
- Distinct(Union(nodeToPlan(left), nodeToPlan(right)))
- case Token("TOK_EXCEPT", left :: right :: Nil) =>
- Except(nodeToPlan(left), nodeToPlan(right))
- case Token("TOK_INTERSECT", left :: right :: Nil) =>
- Intersect(nodeToPlan(left), nodeToPlan(right))
-
- case _ =>
- noParseRule("Plan", node)
- }
-
- val allJoinTokens = "(TOK_.*JOIN)".r
- val laterViewToken = "TOK_LATERAL_VIEW(.*)".r
- protected def nodeToRelation(node: ASTNode): LogicalPlan = {
- node match {
- case Token("TOK_SUBQUERY", query :: Token(alias, Nil) :: Nil) =>
- SubqueryAlias(cleanIdentifier(alias), nodeToPlan(query))
-
- case Token(laterViewToken(isOuter), selectClause :: relationClause :: Nil) =>
- nodeToGenerate(
- selectClause,
- outer = isOuter.nonEmpty,
- nodeToRelation(relationClause))
-
- /* All relations, possibly with aliases or sampling clauses. */
- case Token("TOK_TABREF", clauses) =>
- // If the last clause is not a token then it's the alias of the table.
- val (nonAliasClauses, aliasClause) =
- if (clauses.last.text.startsWith("TOK")) {
- (clauses, None)
- } else {
- (clauses.dropRight(1), Some(clauses.last))
- }
-
- val (Some(tableNameParts) ::
- splitSampleClause ::
- bucketSampleClause :: Nil) = {
- getClauses(Seq("TOK_TABNAME", "TOK_TABLESPLITSAMPLE", "TOK_TABLEBUCKETSAMPLE"),
- nonAliasClauses)
- }
-
- val tableIdent = extractTableIdent(tableNameParts)
- val alias = aliasClause.map { case Token(a, Nil) => cleanIdentifier(a) }
- val relation = UnresolvedRelation(tableIdent, alias)
-
- // Apply sampling if requested.
- (bucketSampleClause orElse splitSampleClause).map {
- case Token("TOK_TABLESPLITSAMPLE",
- Token("TOK_ROWCOUNT", Nil) :: Token(count, Nil) :: Nil) =>
- Limit(Literal(count.toInt), relation)
- case Token("TOK_TABLESPLITSAMPLE",
- Token("TOK_PERCENT", Nil) :: Token(fraction, Nil) :: Nil) =>
- // The range of fraction accepted by Sample is [0, 1]. Because Hive's block sampling
- // function takes X PERCENT as the input and the range of X is [0, 100], we need to
- // adjust the fraction.
- require(
- fraction.toDouble >= (0.0 - RandomSampler.roundingEpsilon)
- && fraction.toDouble <= (100.0 + RandomSampler.roundingEpsilon),
- s"Sampling fraction ($fraction) must be on interval [0, 100]")
- Sample(0.0, fraction.toDouble / 100, withReplacement = false,
- (math.random * 1000).toInt,
- relation)(
- isTableSample = true)
- case Token("TOK_TABLEBUCKETSAMPLE",
- Token(numerator, Nil) ::
- Token(denominator, Nil) :: Nil) =>
- val fraction = numerator.toDouble / denominator.toDouble
- Sample(0.0, fraction, withReplacement = false, (math.random * 1000).toInt, relation)(
- isTableSample = true)
- case a =>
- noParseRule("Sampling", a)
- }.getOrElse(relation)
-
- case Token(allJoinTokens(joinToken), relation1 :: relation2 :: other) =>
- if (!(other.size <= 1)) {
- sys.error(s"Unsupported join operation: $other")
- }
-
- val (joinType, joinCondition) = getJoinInfo(joinToken, other, node)
-
- Join(nodeToRelation(relation1),
- nodeToRelation(relation2),
- joinType,
- joinCondition)
- case _ =>
- noParseRule("Relation", node)
- }
- }
-
- protected def getJoinInfo(
- joinToken: String,
- joinConditionToken: Seq[ASTNode],
- node: ASTNode): (JoinType, Option[Expression]) = {
- val joinType = joinToken match {
- case "TOK_JOIN" => Inner
- case "TOK_CROSSJOIN" => Inner
- case "TOK_RIGHTOUTERJOIN" => RightOuter
- case "TOK_LEFTOUTERJOIN" => LeftOuter
- case "TOK_FULLOUTERJOIN" => FullOuter
- case "TOK_LEFTSEMIJOIN" => LeftSemi
- case "TOK_UNIQUEJOIN" => noParseRule("Unique Join", node)
- case "TOK_ANTIJOIN" => noParseRule("Anti Join", node)
- case "TOK_NATURALJOIN" => NaturalJoin(Inner)
- case "TOK_NATURALRIGHTOUTERJOIN" => NaturalJoin(RightOuter)
- case "TOK_NATURALLEFTOUTERJOIN" => NaturalJoin(LeftOuter)
- case "TOK_NATURALFULLOUTERJOIN" => NaturalJoin(FullOuter)
- }
-
- joinConditionToken match {
- case Token("TOK_USING", columnList :: Nil) :: Nil =>
- val colNames = columnList.children.collect {
- case Token(name, Nil) => UnresolvedAttribute(name)
- }
- (UsingJoin(joinType, colNames), None)
- /* Join expression specified using ON clause */
- case _ => (joinType, joinConditionToken.headOption.map(nodeToExpr))
- }
- }
-
- protected def nodeToSortOrder(node: ASTNode): SortOrder = node match {
- case Token("TOK_TABSORTCOLNAMEASC", sortExpr :: Nil) =>
- SortOrder(nodeToExpr(sortExpr), Ascending)
- case Token("TOK_TABSORTCOLNAMEDESC", sortExpr :: Nil) =>
- SortOrder(nodeToExpr(sortExpr), Descending)
- case _ =>
- noParseRule("SortOrder", node)
- }
-
- val destinationToken = "TOK_DESTINATION|TOK_INSERT_INTO".r
- protected def nodeToDest(
- node: ASTNode,
- query: LogicalPlan,
- overwrite: Boolean): LogicalPlan = node match {
- case Token(destinationToken(),
- Token("TOK_DIR",
- Token("TOK_TMP_FILE", Nil) :: Nil) :: Nil) =>
- query
-
- case Token(destinationToken(),
- Token("TOK_TAB",
- tableArgs) :: Nil) =>
- val Some(tableNameParts) :: partitionClause :: Nil =
- getClauses(Seq("TOK_TABNAME", "TOK_PARTSPEC"), tableArgs)
-
- val tableIdent = extractTableIdent(tableNameParts)
-
- val partitionKeys = partitionClause.map(_.children.map {
- // Parse partitions. We also make keys case insensitive.
- case Token("TOK_PARTVAL", Token(key, Nil) :: Token(value, Nil) :: Nil) =>
- cleanIdentifier(key.toLowerCase) -> Some(unquoteString(value))
- case Token("TOK_PARTVAL", Token(key, Nil) :: Nil) =>
- cleanIdentifier(key.toLowerCase) -> None
- }.toMap).getOrElse(Map.empty)
-
- InsertIntoTable(
- UnresolvedRelation(tableIdent, None), partitionKeys, query, overwrite, ifNotExists = false)
-
- case Token(destinationToken(),
- Token("TOK_TAB",
- tableArgs) ::
- Token("TOK_IFNOTEXISTS",
- ifNotExists) :: Nil) =>
- val Some(tableNameParts) :: partitionClause :: Nil =
- getClauses(Seq("TOK_TABNAME", "TOK_PARTSPEC"), tableArgs)
-
- val tableIdent = extractTableIdent(tableNameParts)
-
- val partitionKeys = partitionClause.map(_.children.map {
- // Parse partitions. We also make keys case insensitive.
- case Token("TOK_PARTVAL", Token(key, Nil) :: Token(value, Nil) :: Nil) =>
- cleanIdentifier(key.toLowerCase) -> Some(unquoteString(value))
- case Token("TOK_PARTVAL", Token(key, Nil) :: Nil) =>
- cleanIdentifier(key.toLowerCase) -> None
- }.toMap).getOrElse(Map.empty)
-
- InsertIntoTable(
- UnresolvedRelation(tableIdent, None), partitionKeys, query, overwrite, ifNotExists = true)
-
- case _ =>
- noParseRule("Destination", node)
- }
-
- protected def selExprNodeToExpr(node: ASTNode): Option[Expression] = node match {
- case Token("TOK_SELEXPR", e :: Nil) =>
- Some(nodeToExpr(e))
-
- case Token("TOK_SELEXPR", e :: Token(alias, Nil) :: Nil) =>
- Some(Alias(nodeToExpr(e), cleanIdentifier(alias))())
-
- case Token("TOK_SELEXPR", e :: aliasChildren) =>
- val aliasNames = aliasChildren.collect {
- case Token(name, Nil) => cleanIdentifier(name)
- }
- Some(MultiAlias(nodeToExpr(e), aliasNames))
-
- /* Hints are ignored */
- case Token("TOK_HINTLIST", _) => None
-
- case _ =>
- noParseRule("Select", node)
- }
-
- /**
- * Flattens the left deep tree with the specified pattern into a list.
- */
- private def flattenLeftDeepTree(node: ASTNode, pattern: Regex): Seq[ASTNode] = {
- val collected = ArrayBuffer[ASTNode]()
- var rest = node
- while (rest match {
- case Token(pattern(), l :: r :: Nil) =>
- collected += r
- rest = l
- true
- case _ => false
- }) {
- // do nothing
- }
- collected += rest
- // keep them in the same order as in SQL
- collected.reverse
- }
-
- /**
- * Creates a balanced tree that has similar number of nodes on left and right.
- *
- * This help to reduce the depth of the tree to prevent StackOverflow in analyzer/optimizer.
- */
- private def balancedTree(
- expr: Seq[Expression],
- f: (Expression, Expression) => Expression): Expression = expr.length match {
- case 1 => expr.head
- case 2 => f(expr.head, expr(1))
- case l => f(balancedTree(expr.slice(0, l / 2), f), balancedTree(expr.slice(l / 2, l), f))
- }
-
- protected def nodeToExpr(node: ASTNode): Expression = node match {
- /* Attribute References */
- case Token("TOK_TABLE_OR_COL", Token(name, Nil) :: Nil) =>
- UnresolvedAttribute.quoted(cleanIdentifier(name))
- case Token(".", qualifier :: Token(attr, Nil) :: Nil) =>
- nodeToExpr(qualifier) match {
- case UnresolvedAttribute(nameParts) =>
- UnresolvedAttribute(nameParts :+ cleanIdentifier(attr))
- case other => UnresolvedExtractValue(other, Literal(cleanIdentifier(attr)))
- }
- case Token("TOK_SUBQUERY_EXPR", Token("TOK_SUBQUERY_OP", Nil) :: subquery :: Nil) =>
- ScalarSubquery(nodeToPlan(subquery))
-
- /* Stars (*) */
- case Token("TOK_ALLCOLREF", Nil) => UnresolvedStar(None)
- // The format of dbName.tableName.* cannot be parsed by HiveParser. TOK_TABNAME will only
- // has a single child which is tableName.
- case Token("TOK_ALLCOLREF", Token("TOK_TABNAME", target) :: Nil) if target.nonEmpty =>
- UnresolvedStar(Some(target.map(x => cleanIdentifier(x.text))))
-
- /* Aggregate Functions */
- case Token("TOK_FUNCTIONDI", Token(COUNT(), Nil) :: args) =>
- Count(args.map(nodeToExpr)).toAggregateExpression(isDistinct = true)
- case Token("TOK_FUNCTIONSTAR", Token(COUNT(), Nil) :: Nil) =>
- Count(Literal(1)).toAggregateExpression()
-
- /* Casts */
- case Token("TOK_FUNCTION", Token("TOK_STRING", Nil) :: arg :: Nil) =>
- Cast(nodeToExpr(arg), StringType)
- case Token("TOK_FUNCTION", Token("TOK_VARCHAR", _) :: arg :: Nil) =>
- Cast(nodeToExpr(arg), StringType)
- case Token("TOK_FUNCTION", Token("TOK_CHAR", _) :: arg :: Nil) =>
- Cast(nodeToExpr(arg), StringType)
- case Token("TOK_FUNCTION", Token("TOK_INT", Nil) :: arg :: Nil) =>
- Cast(nodeToExpr(arg), IntegerType)
- case Token("TOK_FUNCTION", Token("TOK_BIGINT", Nil) :: arg :: Nil) =>
- Cast(nodeToExpr(arg), LongType)
- case Token("TOK_FUNCTION", Token("TOK_FLOAT", Nil) :: arg :: Nil) =>
- Cast(nodeToExpr(arg), FloatType)
- case Token("TOK_FUNCTION", Token("TOK_DOUBLE", Nil) :: arg :: Nil) =>
- Cast(nodeToExpr(arg), DoubleType)
- case Token("TOK_FUNCTION", Token("TOK_SMALLINT", Nil) :: arg :: Nil) =>
- Cast(nodeToExpr(arg), ShortType)
- case Token("TOK_FUNCTION", Token("TOK_TINYINT", Nil) :: arg :: Nil) =>
- Cast(nodeToExpr(arg), ByteType)
- case Token("TOK_FUNCTION", Token("TOK_BINARY", Nil) :: arg :: Nil) =>
- Cast(nodeToExpr(arg), BinaryType)
- case Token("TOK_FUNCTION", Token("TOK_BOOLEAN", Nil) :: arg :: Nil) =>
- Cast(nodeToExpr(arg), BooleanType)
- case Token("TOK_FUNCTION", Token("TOK_DECIMAL", precision :: scale :: nil) :: arg :: Nil) =>
- Cast(nodeToExpr(arg), DecimalType(precision.text.toInt, scale.text.toInt))
- case Token("TOK_FUNCTION", Token("TOK_DECIMAL", precision :: Nil) :: arg :: Nil) =>
- Cast(nodeToExpr(arg), DecimalType(precision.text.toInt, 0))
- case Token("TOK_FUNCTION", Token("TOK_DECIMAL", Nil) :: arg :: Nil) =>
- Cast(nodeToExpr(arg), DecimalType.USER_DEFAULT)
- case Token("TOK_FUNCTION", Token("TOK_TIMESTAMP", Nil) :: arg :: Nil) =>
- Cast(nodeToExpr(arg), TimestampType)
- case Token("TOK_FUNCTION", Token("TOK_DATE", Nil) :: arg :: Nil) =>
- Cast(nodeToExpr(arg), DateType)
-
- /* Arithmetic */
- case Token("+", child :: Nil) => nodeToExpr(child)
- case Token("-", child :: Nil) => UnaryMinus(nodeToExpr(child))
- case Token("~", child :: Nil) => BitwiseNot(nodeToExpr(child))
- case Token("+", left :: right:: Nil) => Add(nodeToExpr(left), nodeToExpr(right))
- case Token("-", left :: right:: Nil) => Subtract(nodeToExpr(left), nodeToExpr(right))
- case Token("*", left :: right:: Nil) => Multiply(nodeToExpr(left), nodeToExpr(right))
- case Token("/", left :: right:: Nil) => Divide(nodeToExpr(left), nodeToExpr(right))
- case Token(DIV(), left :: right:: Nil) =>
- Cast(Divide(nodeToExpr(left), nodeToExpr(right)), LongType)
- case Token("%", left :: right:: Nil) => Remainder(nodeToExpr(left), nodeToExpr(right))
- case Token("&", left :: right:: Nil) => BitwiseAnd(nodeToExpr(left), nodeToExpr(right))
- case Token("|", left :: right:: Nil) => BitwiseOr(nodeToExpr(left), nodeToExpr(right))
- case Token("^", left :: right:: Nil) => BitwiseXor(nodeToExpr(left), nodeToExpr(right))
-
- /* Comparisons */
- case Token("=", left :: right:: Nil) => EqualTo(nodeToExpr(left), nodeToExpr(right))
- case Token("==", left :: right:: Nil) => EqualTo(nodeToExpr(left), nodeToExpr(right))
- case Token("<=>", left :: right:: Nil) => EqualNullSafe(nodeToExpr(left), nodeToExpr(right))
- case Token("!=", left :: right:: Nil) => Not(EqualTo(nodeToExpr(left), nodeToExpr(right)))
- case Token("<>", left :: right:: Nil) => Not(EqualTo(nodeToExpr(left), nodeToExpr(right)))
- case Token(">", left :: right:: Nil) => GreaterThan(nodeToExpr(left), nodeToExpr(right))
- case Token(">=", left :: right:: Nil) => GreaterThanOrEqual(nodeToExpr(left), nodeToExpr(right))
- case Token("<", left :: right:: Nil) => LessThan(nodeToExpr(left), nodeToExpr(right))
- case Token("<=", left :: right:: Nil) => LessThanOrEqual(nodeToExpr(left), nodeToExpr(right))
- case Token(LIKE(), left :: right:: Nil) => Like(nodeToExpr(left), nodeToExpr(right))
- case Token(RLIKE(), left :: right:: Nil) => RLike(nodeToExpr(left), nodeToExpr(right))
- case Token(REGEXP(), left :: right:: Nil) => RLike(nodeToExpr(left), nodeToExpr(right))
- case Token("TOK_FUNCTION", Token("TOK_ISNOTNULL", Nil) :: child :: Nil) =>
- IsNotNull(nodeToExpr(child))
- case Token("TOK_FUNCTION", Token("TOK_ISNULL", Nil) :: child :: Nil) =>
- IsNull(nodeToExpr(child))
- case Token("TOK_FUNCTION", Token(IN(), Nil) :: value :: list) =>
- In(nodeToExpr(value), list.map(nodeToExpr))
- case Token("TOK_FUNCTION",
- Token(BETWEEN(), Nil) ::
- kw ::
- target ::
- minValue ::
- maxValue :: Nil) =>
-
- val targetExpression = nodeToExpr(target)
- val betweenExpr =
- And(
- GreaterThanOrEqual(targetExpression, nodeToExpr(minValue)),
- LessThanOrEqual(targetExpression, nodeToExpr(maxValue)))
- kw match {
- case Token("KW_FALSE", Nil) => betweenExpr
- case Token("KW_TRUE", Nil) => Not(betweenExpr)
- }
-
- /* Boolean Logic */
- case Token(AND(), left :: right:: Nil) =>
- balancedTree(flattenLeftDeepTree(node, AND).map(nodeToExpr), And)
- case Token(OR(), left :: right:: Nil) =>
- balancedTree(flattenLeftDeepTree(node, OR).map(nodeToExpr), Or)
- case Token(NOT(), child :: Nil) => Not(nodeToExpr(child))
- case Token("!", child :: Nil) => Not(nodeToExpr(child))
-
- /* Case statements */
- case Token("TOK_FUNCTION", Token(WHEN(), Nil) :: branches) =>
- CaseWhen.createFromParser(branches.map(nodeToExpr))
- case Token("TOK_FUNCTION", Token(CASE(), Nil) :: branches) =>
- val keyExpr = nodeToExpr(branches.head)
- CaseKeyWhen(keyExpr, branches.drop(1).map(nodeToExpr))
-
- /* Complex datatype manipulation */
- case Token("[", child :: ordinal :: Nil) =>
- UnresolvedExtractValue(nodeToExpr(child), nodeToExpr(ordinal))
-
- /* Window Functions */
- case Token(text, args :+ Token("TOK_WINDOWSPEC", spec)) =>
- val function = nodeToExpr(node.copy(children = node.children.init))
- nodesToWindowSpecification(spec) match {
- case reference: WindowSpecReference =>
- UnresolvedWindowExpression(function, reference)
- case definition: WindowSpecDefinition =>
- WindowExpression(function, definition)
- }
-
- /* UDFs - Must be last otherwise will preempt built in functions */
- case Token("TOK_FUNCTION", Token(name, Nil) :: args) =>
- UnresolvedFunction(name, args.map(nodeToExpr), isDistinct = false)
- // Aggregate function with DISTINCT keyword.
- case Token("TOK_FUNCTIONDI", Token(name, Nil) :: args) =>
- UnresolvedFunction(name, args.map(nodeToExpr), isDistinct = true)
- case Token("TOK_FUNCTIONSTAR", Token(name, Nil) :: args) =>
- UnresolvedFunction(name, UnresolvedStar(None) :: Nil, isDistinct = false)
-
- /* Literals */
- case Token("TOK_NULL", Nil) => Literal.create(null, NullType)
- case Token(TRUE(), Nil) => Literal.create(true, BooleanType)
- case Token(FALSE(), Nil) => Literal.create(false, BooleanType)
- case Token("TOK_STRINGLITERALSEQUENCE", strings) =>
- Literal(strings.map(s => ParseUtils.unescapeSQLString(s.text)).mkString)
-
- case ast if ast.tokenType == SparkSqlParser.TinyintLiteral =>
- Literal.create(ast.text.substring(0, ast.text.length() - 1).toByte, ByteType)
-
- case ast if ast.tokenType == SparkSqlParser.SmallintLiteral =>
- Literal.create(ast.text.substring(0, ast.text.length() - 1).toShort, ShortType)
-
- case ast if ast.tokenType == SparkSqlParser.BigintLiteral =>
- Literal.create(ast.text.substring(0, ast.text.length() - 1).toLong, LongType)
-
- case ast if ast.tokenType == SparkSqlParser.DoubleLiteral =>
- Literal(ast.text.toDouble)
-
- case ast if ast.tokenType == SparkSqlParser.Number =>
- val text = ast.text
- text match {
- case INTEGRAL() =>
- BigDecimal(text) match {
- case v if v.isValidInt =>
- Literal(v.intValue())
- case v if v.isValidLong =>
- Literal(v.longValue())
- case v => Literal(v.underlying())
- }
- case DECIMAL(_*) =>
- Literal(BigDecimal(text).underlying())
- case _ =>
- // Convert a scientifically notated decimal into a double.
- Literal(text.toDouble)
- }
- case ast if ast.tokenType == SparkSqlParser.StringLiteral =>
- Literal(ParseUtils.unescapeSQLString(ast.text))
-
- case ast if ast.tokenType == SparkSqlParser.TOK_DATELITERAL =>
- Literal(Date.valueOf(ast.text.substring(1, ast.text.length - 1)))
-
- case ast if ast.tokenType == SparkSqlParser.TOK_INTERVAL_YEAR_MONTH_LITERAL =>
- Literal(CalendarInterval.fromYearMonthString(ast.children.head.text))
-
- case ast if ast.tokenType == SparkSqlParser.TOK_INTERVAL_DAY_TIME_LITERAL =>
- Literal(CalendarInterval.fromDayTimeString(ast.children.head.text))
-
- case Token("TOK_INTERVAL", elements) =>
- var interval = new CalendarInterval(0, 0)
- var updated = false
- elements.foreach {
- // The interval node will always contain children for all possible time units. A child node
- // is only useful when it contains exactly one (numeric) child.
- case e @ Token(name, Token(value, Nil) :: Nil) =>
- val unit = name match {
- case "TOK_INTERVAL_YEAR_LITERAL" => "year"
- case "TOK_INTERVAL_MONTH_LITERAL" => "month"
- case "TOK_INTERVAL_WEEK_LITERAL" => "week"
- case "TOK_INTERVAL_DAY_LITERAL" => "day"
- case "TOK_INTERVAL_HOUR_LITERAL" => "hour"
- case "TOK_INTERVAL_MINUTE_LITERAL" => "minute"
- case "TOK_INTERVAL_SECOND_LITERAL" => "second"
- case "TOK_INTERVAL_MILLISECOND_LITERAL" => "millisecond"
- case "TOK_INTERVAL_MICROSECOND_LITERAL" => "microsecond"
- case _ => noParseRule(s"Interval($name)", e)
- }
- interval = interval.add(CalendarInterval.fromSingleUnitString(unit, value))
- updated = true
- case _ =>
- }
- if (!updated) {
- throw new AnalysisException("at least one time unit should be given for interval literal")
- }
- Literal(interval)
-
- case _ =>
- noParseRule("Expression", node)
- }
-
- /* Case insensitive matches for Window Specification */
- val PRECEDING = "(?i)preceding".r
- val FOLLOWING = "(?i)following".r
- val CURRENT = "(?i)current".r
- protected def nodesToWindowSpecification(nodes: Seq[ASTNode]): WindowSpec = nodes match {
- case Token(windowName, Nil) :: Nil =>
- // Refer to a window spec defined in the window clause.
- WindowSpecReference(windowName)
- case Nil =>
- // OVER()
- WindowSpecDefinition(
- partitionSpec = Nil,
- orderSpec = Nil,
- frameSpecification = UnspecifiedFrame)
- case spec =>
- val (partitionClause :: rowFrame :: rangeFrame :: Nil) =
- getClauses(
- Seq(
- "TOK_PARTITIONINGSPEC",
- "TOK_WINDOWRANGE",
- "TOK_WINDOWVALUES"),
- spec)
-
- // Handle Partition By and Order By.
- val (partitionSpec, orderSpec) = partitionClause.map { partitionAndOrdering =>
- val (partitionByClause :: orderByClause :: sortByClause :: clusterByClause :: Nil) =
- getClauses(
- Seq("TOK_DISTRIBUTEBY", "TOK_ORDERBY", "TOK_SORTBY", "TOK_CLUSTERBY"),
- partitionAndOrdering.children)
-
- (partitionByClause, orderByClause.orElse(sortByClause), clusterByClause) match {
- case (Some(partitionByExpr), Some(orderByExpr), None) =>
- (partitionByExpr.children.map(nodeToExpr),
- orderByExpr.children.map(nodeToSortOrder))
- case (Some(partitionByExpr), None, None) =>
- (partitionByExpr.children.map(nodeToExpr), Nil)
- case (None, Some(orderByExpr), None) =>
- (Nil, orderByExpr.children.map(nodeToSortOrder))
- case (None, None, Some(clusterByExpr)) =>
- val expressions = clusterByExpr.children.map(nodeToExpr)
- (expressions, expressions.map(SortOrder(_, Ascending)))
- case _ =>
- noParseRule("Partition & Ordering", partitionAndOrdering)
- }
- }.getOrElse {
- (Nil, Nil)
- }
-
- // Handle Window Frame
- val windowFrame =
- if (rowFrame.isEmpty && rangeFrame.isEmpty) {
- UnspecifiedFrame
- } else {
- val frameType = rowFrame.map(_ => RowFrame).getOrElse(RangeFrame)
- def nodeToBoundary(node: ASTNode): FrameBoundary = node match {
- case Token(PRECEDING(), Token(count, Nil) :: Nil) =>
- if (count.toLowerCase() == "unbounded") {
- UnboundedPreceding
- } else {
- ValuePreceding(count.toInt)
- }
- case Token(FOLLOWING(), Token(count, Nil) :: Nil) =>
- if (count.toLowerCase() == "unbounded") {
- UnboundedFollowing
- } else {
- ValueFollowing(count.toInt)
- }
- case Token(CURRENT(), Nil) => CurrentRow
- case _ =>
- noParseRule("Window Frame Boundary", node)
- }
-
- rowFrame.orElse(rangeFrame).map { frame =>
- frame.children match {
- case precedingNode :: followingNode :: Nil =>
- SpecifiedWindowFrame(
- frameType,
- nodeToBoundary(precedingNode),
- nodeToBoundary(followingNode))
- case precedingNode :: Nil =>
- SpecifiedWindowFrame(frameType, nodeToBoundary(precedingNode), CurrentRow)
- case _ =>
- noParseRule("Window Frame", frame)
- }
- }.getOrElse(sys.error(s"If you see this, please file a bug report with your query."))
- }
-
- WindowSpecDefinition(partitionSpec, orderSpec, windowFrame)
- }
-
- protected def nodeToTransformation(
- node: ASTNode,
- child: LogicalPlan): Option[ScriptTransformation] = None
-
- val explode = "(?i)explode".r
- val jsonTuple = "(?i)json_tuple".r
- protected def nodeToGenerate(node: ASTNode, outer: Boolean, child: LogicalPlan): Generate = {
- val Token("TOK_SELECT", Token("TOK_SELEXPR", clauses) :: Nil) = node
-
- val alias = cleanIdentifier(getClause("TOK_TABALIAS", clauses).children.head.text)
-
- val generator = clauses.head match {
- case Token("TOK_FUNCTION", Token(explode(), Nil) :: childNode :: Nil) =>
- Explode(nodeToExpr(childNode))
- case Token("TOK_FUNCTION", Token(jsonTuple(), Nil) :: children) =>
- JsonTuple(children.map(nodeToExpr))
- case other =>
- nodeToGenerator(other)
- }
-
- val attributes = clauses.collect {
- case Token(a, Nil) => UnresolvedAttribute(cleanIdentifier(a.toLowerCase))
- }
-
- Generate(
- generator,
- join = true,
- outer = outer,
- Some(cleanIdentifier(alias.toLowerCase)),
- attributes,
- child)
- }
-
- protected def nodeToGenerator(node: ASTNode): Generator = noParseRule("Generator", node)
-
-}
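For reference, the numeric-literal fallback that the removed Number case implemented (try Int, then Long, then an exact decimal) can be sketched in isolation. The helper below, numberLiteral, is a hypothetical name written only for illustration; it is not code from this patch.

    import org.apache.spark.sql.catalyst.expressions.Literal

    // Mirror of the removed BigDecimal match: widen only as far as the value requires.
    def numberLiteral(text: String): Literal = BigDecimal(text) match {
      case v if v.isValidInt  => Literal(v.intValue())    // IntegerType
      case v if v.isValidLong => Literal(v.longValue())   // LongType
      case v                  => Literal(v.underlying())  // exact java.math.BigDecimal -> DecimalType
    }

    numberLiteral("42")                    // integer literal
    numberLiteral("3000000000")            // long literal
    numberLiteral("12345678901234567890")  // decimal literal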
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/DataTypeParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/DataTypeParser.scala
index 21deb82107..0b570c9e42 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/DataTypeParser.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/DataTypeParser.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.parser
import scala.language.implicitConversions
import scala.util.matching.Regex
import scala.util.parsing.combinator.syntactical.StandardTokenParsers
+import scala.util.parsing.input.CharArrayReader._
import org.apache.spark.sql.types._
@@ -117,3 +118,69 @@ private[sql] object DataTypeParser {
/** The exception thrown from the [[DataTypeParser]]. */
private[sql] class DataTypeException(message: String) extends Exception(message)
+
+class SqlLexical extends scala.util.parsing.combinator.lexical.StdLexical {
+ case class DecimalLit(chars: String) extends Token {
+ override def toString: String = chars
+ }
+
+ /* This is a workaround to support setting the keywords lazily. */
+ def initialize(keywords: Seq[String]): Unit = {
+ reserved.clear()
+ reserved ++= keywords
+ }
+
+ /* Normalize the keyword string. */
+ def normalizeKeyword(str: String): String = str.toLowerCase
+
+ delimiters += (
+ "@", "*", "+", "-", "<", "=", "<>", "!=", "<=", ">=", ">", "/", "(", ")",
+ ",", ";", "%", "{", "}", ":", "[", "]", ".", "&", "|", "^", "~", "<=>"
+ )
+
+ protected override def processIdent(name: String) = {
+ val token = normalizeKeyword(name)
+ if (reserved contains token) Keyword(token) else Identifier(name)
+ }
+
+ override lazy val token: Parser[Token] =
+ ( rep1(digit) ~ scientificNotation ^^ { case i ~ s => DecimalLit(i.mkString + s) }
+ | '.' ~> (rep1(digit) ~ scientificNotation) ^^
+ { case i ~ s => DecimalLit("0." + i.mkString + s) }
+ | rep1(digit) ~ ('.' ~> digit.*) ~ scientificNotation ^^
+ { case i1 ~ i2 ~ s => DecimalLit(i1.mkString + "." + i2.mkString + s) }
+ | digit.* ~ identChar ~ (identChar | digit).* ^^
+ { case first ~ middle ~ rest => processIdent((first ++ (middle :: rest)).mkString) }
+ | rep1(digit) ~ ('.' ~> digit.*).? ^^ {
+ case i ~ None => NumericLit(i.mkString)
+ case i ~ Some(d) => DecimalLit(i.mkString + "." + d.mkString)
+ }
+ | '\'' ~> chrExcept('\'', '\n', EofCh).* <~ '\'' ^^
+ { case chars => StringLit(chars mkString "") }
+ | '"' ~> chrExcept('"', '\n', EofCh).* <~ '"' ^^
+ { case chars => StringLit(chars mkString "") }
+ | '`' ~> chrExcept('`', '\n', EofCh).* <~ '`' ^^
+ { case chars => Identifier(chars mkString "") }
+ | EofCh ^^^ EOF
+ | '\'' ~> failure("unclosed string literal")
+ | '"' ~> failure("unclosed string literal")
+ | delim
+ | failure("illegal character")
+ )
+
+ override def identChar: Parser[Elem] = letter | elem('_')
+
+ private lazy val scientificNotation: Parser[String] =
+ (elem('e') | elem('E')) ~> (elem('+') | elem('-')).? ~ rep1(digit) ^^ {
+ case s ~ rest => "e" + s.mkString + rest.mkString
+ }
+
+ override def whitespace: Parser[Any] =
+ ( whitespaceChar
+ | '/' ~ '*' ~ comment
+ | '/' ~ '/' ~ chrExcept(EofCh, '\n').*
+ | '#' ~ chrExcept(EofCh, '\n').*
+ | '-' ~ '-' ~ chrExcept(EofCh, '\n').*
+ | '/' ~ '*' ~ failure("unclosed comment")
+ ).*
+}
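A minimal usage sketch for the SqlLexical class added above (illustrative only, not part of the patch): initialize the keyword set, then drive the Scanner inherited from StdLexical and collect the tokens, including the new DecimalLit produced for scientific notation.

    import org.apache.spark.sql.catalyst.parser.SqlLexical

    val lexical = new SqlLexical
    lexical.initialize(Seq("select"))   // keywords are installed lazily via initialize

    var in: scala.util.parsing.input.Reader[lexical.Token] = new lexical.Scanner("select 1.5e3")
    val tokens = scala.collection.mutable.Buffer.empty[lexical.Token]
    while (!in.atEnd) {
      tokens += in.first
      in = in.rest
    }
    // tokens now holds Keyword("select") and DecimalLit("1.5e3")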
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala
index 51cfc50130..d0132529f1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala
@@ -16,91 +16,106 @@
*/
package org.apache.spark.sql.catalyst.parser
-import scala.annotation.tailrec
-
-import org.antlr.runtime._
-import org.antlr.runtime.tree.CommonTree
+import org.antlr.v4.runtime._
+import org.antlr.v4.runtime.atn.PredictionMode
+import org.antlr.v4.runtime.misc.ParseCancellationException
import org.apache.spark.internal.Logging
import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.catalyst.TableIdentifier
+import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.trees.Origin
+import org.apache.spark.sql.types.DataType
/**
- * The ParseDriver takes a SQL command and turns this into an AST.
- *
- * This is based on Hive's org.apache.hadoop.hive.ql.parse.ParseDriver
+ * Base SQL parsing infrastructure.
*/
-object ParseDriver extends Logging {
- /** Create an LogicalPlan ASTNode from a SQL command. */
- def parsePlan(command: String, conf: ParserConf): ASTNode = parse(command, conf) { parser =>
- parser.statement().getTree
- }
+abstract class AbstractSqlParser extends ParserInterface with Logging {
- /** Create an Expression ASTNode from a SQL command. */
- def parseExpression(command: String, conf: ParserConf): ASTNode = parse(command, conf) { parser =>
- parser.singleNamedExpression().getTree
+ /** Creates/Resolves DataType for a given SQL string. */
+ def parseDataType(sqlText: String): DataType = parse(sqlText) { parser =>
+ // TODO add this to the parser interface.
+ astBuilder.visitSingleDataType(parser.singleDataType())
}
- /** Create an TableIdentifier ASTNode from a SQL command. */
- def parseTableName(command: String, conf: ParserConf): ASTNode = parse(command, conf) { parser =>
- parser.singleTableName().getTree
+ /** Creates Expression for a given SQL string. */
+ override def parseExpression(sqlText: String): Expression = parse(sqlText) { parser =>
+ astBuilder.visitSingleExpression(parser.singleExpression())
}
- private def parse(
- command: String,
- conf: ParserConf)(
- toTree: SparkSqlParser => CommonTree): ASTNode = {
- logInfo(s"Parsing command: $command")
+ /** Creates TableIdentifier for a given SQL string. */
+ override def parseTableIdentifier(sqlText: String): TableIdentifier = parse(sqlText) { parser =>
+ astBuilder.visitSingleTableIdentifier(parser.singleTableIdentifier())
+ }
- // Setup error collection.
- val reporter = new ParseErrorReporter()
+ /** Creates LogicalPlan for a given SQL string. */
+ override def parsePlan(sqlText: String): LogicalPlan = parse(sqlText) { parser =>
+ astBuilder.visitSingleStatement(parser.singleStatement()) match {
+ case plan: LogicalPlan => plan
+ case _ => nativeCommand(sqlText)
+ }
+ }
- // Create lexer.
- val lexer = new SparkSqlLexer(new ANTLRNoCaseStringStream(command))
- val tokens = new TokenRewriteStream(lexer)
- lexer.configure(conf, reporter)
+ /** Get the builder (visitor) which converts a ParseTree into an AST. */
+ protected def astBuilder: AstBuilder
- // Create the parser.
- val parser = new SparkSqlParser(tokens)
- parser.configure(conf, reporter)
+ /** Create a native command, or fail when this is not supported. */
+ protected def nativeCommand(sqlText: String): LogicalPlan = {
+ val position = Origin(None, None)
+ throw new ParseException(Option(sqlText), "Unsupported SQL statement", position, position)
+ }
- try {
- val result = toTree(parser)
+ protected def parse[T](command: String)(toResult: SqlBaseParser => T): T = {
+ logInfo(s"Parsing command: $command")
- // Check errors.
- reporter.checkForErrors()
+ val lexer = new SqlBaseLexer(new ANTLRNoCaseStringStream(command))
+ lexer.removeErrorListeners()
+ lexer.addErrorListener(ParseErrorListener)
- // Return the AST node from the result.
- logInfo(s"Parse completed.")
+ val tokenStream = new CommonTokenStream(lexer)
+ val parser = new SqlBaseParser(tokenStream)
+ parser.addParseListener(PostProcessor)
+ parser.removeErrorListeners()
+ parser.addErrorListener(ParseErrorListener)
- // Find the non null token tree in the result.
- @tailrec
- def nonNullToken(tree: CommonTree): CommonTree = {
- if (tree.token != null || tree.getChildCount == 0) tree
- else nonNullToken(tree.getChild(0).asInstanceOf[CommonTree])
+ try {
+ try {
+ // first, try parsing with potentially faster SLL mode
+ parser.getInterpreter.setPredictionMode(PredictionMode.SLL)
+ toResult(parser)
}
- val tree = nonNullToken(result)
-
- // Make sure all boundaries are set.
- tree.setUnknownTokenBoundaries()
-
- // Construct the immutable AST.
- def createASTNode(tree: CommonTree): ASTNode = {
- val children = (0 until tree.getChildCount).map { i =>
- createASTNode(tree.getChild(i).asInstanceOf[CommonTree])
- }.toList
- ASTNode(tree.token, tree.getTokenStartIndex, tree.getTokenStopIndex, children, tokens)
+ catch {
+ case e: ParseCancellationException =>
+ // if we fail, parse with LL mode
+ tokenStream.reset() // rewind input stream
+ parser.reset()
+
+ // Try Again.
+ parser.getInterpreter.setPredictionMode(PredictionMode.LL)
+ toResult(parser)
}
- createASTNode(tree)
}
catch {
- case e: RecognitionException =>
- logInfo(s"Parse failed.")
- reporter.throwError(e)
+ case e: ParseException if e.command.isDefined =>
+ throw e
+ case e: ParseException =>
+ throw e.withCommand(command)
+ case e: AnalysisException =>
+ val position = Origin(e.line, e.startPosition)
+ throw new ParseException(Option(command), e.message, position, position)
}
}
}
/**
+ * Concrete SQL parser for Catalyst-only SQL statements.
+ */
+object CatalystSqlParser extends AbstractSqlParser {
+ val astBuilder = new AstBuilder
+}
+
+/**
* This string stream provides the lexer with upper case characters only. This greatly simplifies
* lexing the stream, while we can maintain the original command.
*
@@ -120,58 +135,104 @@ object ParseDriver extends Logging {
* have the ANTLRNoCaseStringStream implementation.
*/
-private[parser] class ANTLRNoCaseStringStream(input: String) extends ANTLRStringStream(input) {
+private[parser] class ANTLRNoCaseStringStream(input: String) extends ANTLRInputStream(input) {
override def LA(i: Int): Int = {
val la = super.LA(i)
- if (la == 0 || la == CharStream.EOF) la
+ if (la == 0 || la == IntStream.EOF) la
else Character.toUpperCase(la)
}
}
/**
- * Utility used by the Parser and the Lexer for error collection and reporting.
+ * The ParseErrorListener converts parse errors into AnalysisExceptions.
*/
-private[parser] class ParseErrorReporter {
- val errors = scala.collection.mutable.Buffer.empty[ParseError]
-
- def report(br: BaseRecognizer, re: RecognitionException, tokenNames: Array[String]): Unit = {
- errors += ParseError(br, re, tokenNames)
+case object ParseErrorListener extends BaseErrorListener {
+ override def syntaxError(
+ recognizer: Recognizer[_, _],
+ offendingSymbol: scala.Any,
+ line: Int,
+ charPositionInLine: Int,
+ msg: String,
+ e: RecognitionException): Unit = {
+ val position = Origin(Some(line), Some(charPositionInLine))
+ throw new ParseException(None, msg, position, position)
}
+}
- def checkForErrors(): Unit = {
- if (errors.nonEmpty) {
- val first = errors.head
- val e = first.re
- throwError(e.line, e.charPositionInLine, first.buildMessage().toString, errors.tail)
- }
+/**
+ * A [[ParseException]] is an [[AnalysisException]] that is thrown during the parse process. It
+ * contains fields and an extended error message that make reporting and diagnosing errors easier.
+ */
+class ParseException(
+ val command: Option[String],
+ message: String,
+ val start: Origin,
+ val stop: Origin) extends AnalysisException(message, start.line, start.startPosition) {
+
+ def this(message: String, ctx: ParserRuleContext) = {
+ this(Option(ParserUtils.command(ctx)),
+ message,
+ ParserUtils.position(ctx.getStart),
+ ParserUtils.position(ctx.getStop))
}
- def throwError(e: RecognitionException): Nothing = {
- throwError(e.line, e.charPositionInLine, e.toString, errors)
+ override def getMessage: String = {
+ val builder = new StringBuilder
+ builder ++= "\n" ++= message
+ start match {
+ case Origin(Some(l), Some(p)) =>
+ builder ++= s"(line $l, pos $p)\n"
+ command.foreach { cmd =>
+ val (above, below) = cmd.split("\n").splitAt(l)
+ builder ++= "\n== SQL ==\n"
+ above.foreach(builder ++= _ += '\n')
+ builder ++= (0 until p).map(_ => "-").mkString("") ++= "^^^\n"
+ below.foreach(builder ++= _ += '\n')
+ }
+ case _ =>
+ command.foreach { cmd =>
+ builder ++= "\n== SQL ==\n" ++= cmd
+ }
+ }
+ builder.toString
}
- private def throwError(
- line: Int,
- startPosition: Int,
- msg: String,
- errors: Seq[ParseError]): Nothing = {
- val b = new StringBuilder
- b.append(msg).append("\n")
- errors.foreach(error => error.buildMessage(b).append("\n"))
- throw new AnalysisException(b.toString, Option(line), Option(startPosition))
+ def withCommand(cmd: String): ParseException = {
+ new ParseException(Option(cmd), message, start, stop)
}
}
/**
- * Error collected during the parsing process.
- *
- * This is based on Hive's org.apache.hadoop.hive.ql.parse.ParseError
+ * The post-processor validates and cleans up the parse tree during the parse process.
*/
-private[parser] case class ParseError(
- br: BaseRecognizer,
- re: RecognitionException,
- tokenNames: Array[String]) {
- def buildMessage(s: StringBuilder = new StringBuilder): StringBuilder = {
- s.append(br.getErrorHeader(re)).append(" ").append(br.getErrorMessage(re, tokenNames))
+case object PostProcessor extends SqlBaseBaseListener {
+
+ /** Remove the back ticks from an Identifier. */
+ override def exitQuotedIdentifier(ctx: SqlBaseParser.QuotedIdentifierContext): Unit = {
+ replaceTokenByIdentifier(ctx, 1) { token =>
+ // Remove the double back ticks in the string.
+ token.setText(token.getText.replace("``", "`"))
+ token
+ }
+ }
+
+ /** Treat non-reserved keywords as Identifiers. */
+ override def exitNonReserved(ctx: SqlBaseParser.NonReservedContext): Unit = {
+ replaceTokenByIdentifier(ctx, 0)(identity)
+ }
+
+ private def replaceTokenByIdentifier(
+ ctx: ParserRuleContext,
+ stripMargins: Int)(
+ f: CommonToken => CommonToken = identity): Unit = {
+ val parent = ctx.getParent
+ parent.removeLastChild()
+ val token = ctx.getChild(0).getPayload.asInstanceOf[Token]
+ parent.addChild(f(new CommonToken(
+ new org.antlr.v4.runtime.misc.Pair(token.getTokenSource, token.getInputStream),
+ SqlBaseParser.IDENTIFIER,
+ token.getChannel,
+ token.getStartIndex + stripMargins,
+ token.getStopIndex - stripMargins)))
}
}
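A hedged usage sketch of the new entry points (not part of this patch; the malformed statement is arbitrary): the parser now returns Catalyst objects directly instead of AST nodes, and errors surface as a ParseException whose message embeds the SQL text and position.

    import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParseException}

    val expr  = CatalystSqlParser.parseExpression("a + 1")        // unresolved expression tree
    val table = CatalystSqlParser.parseTableIdentifier("db.tbl")  // TableIdentifier("tbl", Some("db"))

    try {
      CatalystSqlParser.parsePlan("SELEC 1")                      // malformed on purpose
    } catch {
      case e: ParseException =>
        println(e.getMessage)  // includes the SQL text and a line/position marker
    }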
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala
index 0c2e481954..cb9fefec8f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala
@@ -14,166 +14,181 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-
package org.apache.spark.sql.catalyst.parser
-import org.apache.spark.sql.AnalysisException
-import org.apache.spark.sql.catalyst.TableIdentifier
-import org.apache.spark.sql.catalyst.trees.CurrentOrigin
-import org.apache.spark.sql.types._
+import scala.collection.mutable.StringBuilder
+
+import org.antlr.v4.runtime.{CharStream, ParserRuleContext, Token}
+import org.antlr.v4.runtime.misc.Interval
+import org.antlr.v4.runtime.tree.TerminalNode
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, Origin}
/**
- * A collection of utility methods and patterns for parsing query texts.
+ * A collection of utility methods for use during the parsing process.
*/
-// TODO: merge with ParseUtils
object ParserUtils {
-
- object Token {
- // Match on (text, children)
- def unapply(node: ASTNode): Some[(String, List[ASTNode])] = {
- CurrentOrigin.setPosition(node.line, node.positionInLine)
- node.pattern
- }
+ /** Get the command which created the token. */
+ def command(ctx: ParserRuleContext): String = {
+ command(ctx.getStart.getInputStream)
}
- private val escapedIdentifier = "`(.+)`".r
- private val doubleQuotedString = "\"([^\"]+)\"".r
- private val singleQuotedString = "'([^']+)'".r
-
- // Token patterns
- val COUNT = "(?i)COUNT".r
- val SUM = "(?i)SUM".r
- val AND = "(?i)AND".r
- val OR = "(?i)OR".r
- val NOT = "(?i)NOT".r
- val TRUE = "(?i)TRUE".r
- val FALSE = "(?i)FALSE".r
- val LIKE = "(?i)LIKE".r
- val RLIKE = "(?i)RLIKE".r
- val REGEXP = "(?i)REGEXP".r
- val IN = "(?i)IN".r
- val DIV = "(?i)DIV".r
- val BETWEEN = "(?i)BETWEEN".r
- val WHEN = "(?i)WHEN".r
- val CASE = "(?i)CASE".r
- val INTEGRAL = "[+-]?\\d+".r
- val DECIMAL = "[+-]?((\\d+(\\.\\d*)?)|(\\.\\d+))".r
-
- /**
- * Strip quotes, if any, from the string.
- */
- def unquoteString(str: String): String = {
- str match {
- case singleQuotedString(s) => s
- case doubleQuotedString(s) => s
- case other => other
- }
+ /** Get the command from the given character stream. */
+ def command(stream: CharStream): String = {
+ stream.getText(Interval.of(0, stream.size()))
}
- /**
- * Strip backticks, if any, from the string.
- */
- def cleanIdentifier(ident: String): String = {
- ident match {
- case escapedIdentifier(i) => i
- case plainIdent => plainIdent
- }
+ /** Get the code that creates the given node. */
+ def source(ctx: ParserRuleContext): String = {
+ val stream = ctx.getStart.getInputStream
+ stream.getText(Interval.of(ctx.getStart.getStartIndex, ctx.getStop.getStopIndex))
}
- def getClauses(
- clauseNames: Seq[String],
- nodeList: Seq[ASTNode]): Seq[Option[ASTNode]] = {
- var remainingNodes = nodeList
- val clauses = clauseNames.map { clauseName =>
- val (matches, nonMatches) = remainingNodes.partition(_.text.toUpperCase == clauseName)
- remainingNodes = nonMatches ++ (if (matches.nonEmpty) matches.tail else Nil)
- matches.headOption
- }
+ /** Get all the text which comes after the given rule. */
+ def remainder(ctx: ParserRuleContext): String = remainder(ctx.getStop)
- if (remainingNodes.nonEmpty) {
- sys.error(
- s"""Unhandled clauses: ${remainingNodes.map(_.treeString).mkString("\n")}.
- |You are likely trying to use an unsupported Hive feature."""".stripMargin)
- }
- clauses
+ /** Get all the text which comes after the given token. */
+ def remainder(token: Token): String = {
+ val stream = token.getInputStream
+ val interval = Interval.of(token.getStopIndex + 1, stream.size())
+ stream.getText(interval)
}
- def getClause(clauseName: String, nodeList: Seq[ASTNode]): ASTNode = {
- getClauseOption(clauseName, nodeList).getOrElse(sys.error(
- s"Expected clause $clauseName missing from ${nodeList.map(_.treeString).mkString("\n")}"))
+ /** Convert a string token into a string. */
+ def string(token: Token): String = unescapeSQLString(token.getText)
+
+ /** Convert a string node into a string. */
+ def string(node: TerminalNode): String = unescapeSQLString(node.getText)
+
+ /** Get the origin (line and position) of the token. */
+ def position(token: Token): Origin = {
+ Origin(Option(token.getLine), Option(token.getCharPositionInLine))
}
- def getClauseOption(clauseName: String, nodeList: Seq[ASTNode]): Option[ASTNode] = {
- nodeList.filter { case ast: ASTNode => ast.text == clauseName } match {
- case Seq(oneMatch) => Some(oneMatch)
- case Seq() => None
- case _ => sys.error(s"Found multiple instances of clause $clauseName")
+ /** Assert that a condition holds. If it does not, throw a parse exception. */
+ def assert(f: => Boolean, message: String, ctx: ParserRuleContext): Unit = {
+ if (!f) {
+ throw new ParseException(message, ctx)
}
}
- def extractTableIdent(tableNameParts: ASTNode): TableIdentifier = {
- tableNameParts.children.map {
- case Token(part, Nil) => cleanIdentifier(part)
- } match {
- case Seq(tableOnly) => TableIdentifier(tableOnly)
- case Seq(databaseName, table) => TableIdentifier(table, Some(databaseName))
- case other => sys.error("Hive only supports tables names like 'tableName' " +
- s"or 'databaseName.tableName', found '$other'")
+ /**
+ * Register the origin of the context. Any TreeNode created in the closure will be assigned the
+ * registered origin. This method restores the previously set origin after completion of the
+ * closure.
+ */
+ def withOrigin[T](ctx: ParserRuleContext)(f: => T): T = {
+ val current = CurrentOrigin.get
+ CurrentOrigin.set(position(ctx.getStart))
+ try {
+ f
+ } finally {
+ CurrentOrigin.set(current)
}
}
- def nodeToDataType(node: ASTNode): DataType = node match {
- case Token("TOK_DECIMAL", precision :: scale :: Nil) =>
- DecimalType(precision.text.toInt, scale.text.toInt)
- case Token("TOK_DECIMAL", precision :: Nil) =>
- DecimalType(precision.text.toInt, 0)
- case Token("TOK_DECIMAL", Nil) => DecimalType.USER_DEFAULT
- case Token("TOK_BIGINT", Nil) => LongType
- case Token("TOK_INT", Nil) => IntegerType
- case Token("TOK_TINYINT", Nil) => ByteType
- case Token("TOK_SMALLINT", Nil) => ShortType
- case Token("TOK_BOOLEAN", Nil) => BooleanType
- case Token("TOK_STRING", Nil) => StringType
- case Token("TOK_VARCHAR", Token(_, Nil) :: Nil) => StringType
- case Token("TOK_CHAR", Token(_, Nil) :: Nil) => StringType
- case Token("TOK_FLOAT", Nil) => FloatType
- case Token("TOK_DOUBLE", Nil) => DoubleType
- case Token("TOK_DATE", Nil) => DateType
- case Token("TOK_TIMESTAMP", Nil) => TimestampType
- case Token("TOK_BINARY", Nil) => BinaryType
- case Token("TOK_LIST", elementType :: Nil) => ArrayType(nodeToDataType(elementType))
- case Token("TOK_STRUCT", Token("TOK_TABCOLLIST", fields) :: Nil) =>
- StructType(fields.map(nodeToStructField))
- case Token("TOK_MAP", keyType :: valueType :: Nil) =>
- MapType(nodeToDataType(keyType), nodeToDataType(valueType))
- case _ =>
- noParseRule("DataType", node)
- }
+ /** Unescape a backslash-escaped string enclosed in quotes. */
+ def unescapeSQLString(b: String): String = {
+ var enclosure: Character = null
+ val sb = new StringBuilder(b.length())
+
+ def appendEscapedChar(n: Char) {
+ n match {
+ case '0' => sb.append('\u0000')
+ case '\'' => sb.append('\'')
+ case '"' => sb.append('\"')
+ case 'b' => sb.append('\b')
+ case 'n' => sb.append('\n')
+ case 'r' => sb.append('\r')
+ case 't' => sb.append('\t')
+ case 'Z' => sb.append('\u001A')
+ case '\\' => sb.append('\\')
+ // The following 2 lines are exactly what MySQL does TODO: why do we do this?
+ case '%' => sb.append("\\%")
+ case '_' => sb.append("\\_")
+ case _ => sb.append(n)
+ }
+ }
- def nodeToStructField(node: ASTNode): StructField = node match {
- case Token("TOK_TABCOL", Token(fieldName, Nil) :: dataType :: Nil) =>
- StructField(cleanIdentifier(fieldName), nodeToDataType(dataType), nullable = true)
- case Token("TOK_TABCOL", Token(fieldName, Nil) :: dataType :: comment :: Nil) =>
- val meta = new MetadataBuilder().putString("comment", unquoteString(comment.text)).build()
- StructField(cleanIdentifier(fieldName), nodeToDataType(dataType), nullable = true, meta)
- case _ =>
- noParseRule("StructField", node)
+ var i = 0
+ val strLength = b.length
+ while (i < strLength) {
+ val currentChar = b.charAt(i)
+ if (enclosure == null) {
+ if (currentChar == '\'' || currentChar == '\"') {
+ enclosure = currentChar
+ }
+ } else if (enclosure == currentChar) {
+ enclosure = null
+ } else if (currentChar == '\\') {
+
+ if ((i + 6 < strLength) && b.charAt(i + 1) == 'u') {
+ // \u0000 style character literals.
+
+ val base = i + 2
+ val code = (0 until 4).foldLeft(0) { (mid, j) =>
+ val digit = Character.digit(b.charAt(j + base), 16)
+ (mid << 4) + digit
+ }
+ sb.append(code.asInstanceOf[Char])
+ i += 5
+ } else if (i + 4 < strLength) {
+ // \000 style character literals.
+
+ val i1 = b.charAt(i + 1)
+ val i2 = b.charAt(i + 2)
+ val i3 = b.charAt(i + 3)
+
+ if ((i1 >= '0' && i1 <= '1') && (i2 >= '0' && i2 <= '7') && (i3 >= '0' && i3 <= '7')) {
+ val tmp = ((i3 - '0') + ((i2 - '0') << 3) + ((i1 - '0') << 6)).asInstanceOf[Char]
+ sb.append(tmp)
+ i += 3
+ } else {
+ appendEscapedChar(i1)
+ i += 1
+ }
+ } else if (i + 2 < strLength) {
+ // escaped character literals.
+ val n = b.charAt(i + 1)
+ appendEscapedChar(n)
+ i += 1
+ }
+ } else {
+ // non-escaped character literals.
+ sb.append(currentChar)
+ }
+ i += 1
+ }
+ sb.toString()
}
- /**
- * Throw an exception because we cannot parse the given node for some unexpected reason.
- */
- def parseFailed(msg: String, node: ASTNode): Nothing = {
- throw new AnalysisException(s"$msg: '${node.source}")
- }
+ /** Some syntactic sugar which makes it easier to work with optional clauses for LogicalPlans. */
+ implicit class EnhancedLogicalPlan(val plan: LogicalPlan) extends AnyVal {
+ /**
+ * Create a plan using the block of code when the given context exists. Otherwise return the
+ * original plan.
+ */
+ def optional(ctx: AnyRef)(f: => LogicalPlan): LogicalPlan = {
+ if (ctx != null) {
+ f
+ } else {
+ plan
+ }
+ }
- /**
- * Throw an exception because there are no rules to parse the node.
- */
- def noParseRule(msg: String, node: ASTNode): Nothing = {
- throw new NotImplementedError(
- s"[$msg]: No parse rules for ASTNode type: ${node.tokenType}, tree:\n${node.treeString}")
+ /**
+ * Map a [[LogicalPlan]] to another [[LogicalPlan]] if the passed context exists using the
+ * passed function. The original plan is returned when the context does not exist.
+ */
+ def optionalMap[C <: ParserRuleContext](
+ ctx: C)(
+ f: (C, LogicalPlan) => LogicalPlan): LogicalPlan = {
+ if (ctx != null) {
+ f(ctx, plan)
+ } else {
+ plan
+ }
+ }
}
-
}
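Two quick spot checks of the unescapeSQLString helper added above (illustrative, written for this description rather than taken from the test suite): the surrounding quotes are consumed and backslash escapes are decoded.

    import org.apache.spark.sql.catalyst.parser.ParserUtils

    ParserUtils.unescapeSQLString("'a\\tb'")   // => "a\tb"  (escaped tab decoded)
    ParserUtils.unescapeSQLString("'it\\'s'")  // => "it's"  (escaped quote decoded)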
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
index 9c927077d0..0065619135 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
@@ -22,6 +22,7 @@ import scala.collection.mutable
import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.types.IntegerType
@@ -68,6 +69,9 @@ object PhysicalOperation extends PredicateHelper {
val substitutedCondition = substitute(aliases)(condition)
(fields, filters ++ splitConjunctivePredicates(substitutedCondition), other, aliases)
+ case BroadcastHint(child) =>
+ collectProjectsAndFilters(child)
+
case other =>
(None, Nil, other, Map.empty)
}
@@ -139,20 +143,20 @@ object ExtractEquiJoinKeys extends Logging with PredicateHelper {
}
/**
- * A pattern that collects the filter and inner joins.
- *
- * Filter
- * |
- * inner Join
- * / \ ----> (Seq(plan0, plan1, plan2), conditions)
- * Filter plan2
- * |
- * inner join
- * / \
- * plan0 plan1
- *
- * Note: This pattern currently only works for left-deep trees.
- */
+ * A pattern that collects the filter and inner joins.
+ *
+ * Filter
+ * |
+ * inner Join
+ * / \ ----> (Seq(plan0, plan1, plan2), conditions)
+ * Filter plan2
+ * |
+ * inner join
+ * / \
+ * plan0 plan1
+ *
+ * Note: This pattern currently only works for left-deep trees.
+ */
object ExtractFiltersAndInnerJoins extends PredicateHelper {
// flatten all inner joins, which are next to each other
@@ -216,3 +220,75 @@ object IntegerIndex {
case _ => None
}
}
+
+/**
+ * An extractor used when planning the physical execution of an aggregation. Compared with a logical
+ * aggregation, the following transformations are performed:
+ * - Unnamed grouping expressions are named so that they can be referred to across phases of
+ * aggregation
+ * - Aggregations that appear multiple times are deduplicated.
+ * - The computation of the aggregations themselves is separated from the final result. For example,
+ * the `count` in `count + 1` will be split into an [[AggregateExpression]] and a final
+ * computation that computes `count.resultAttribute + 1`.
+ */
+object PhysicalAggregation {
+ // groupingExpressions, aggregateExpressions, resultExpressions, child
+ type ReturnType =
+ (Seq[NamedExpression], Seq[AggregateExpression], Seq[NamedExpression], LogicalPlan)
+
+ def unapply(a: Any): Option[ReturnType] = a match {
+ case logical.Aggregate(groupingExpressions, resultExpressions, child) =>
+ // A single aggregate expression might appear multiple times in resultExpressions.
+ // In order to avoid evaluating an individual aggregate function multiple times, we'll
+ // build a set of the distinct aggregate expressions and build a function which can
+ // be used to re-write expressions so that they reference the single copy of the
+ // aggregate function which actually gets computed.
+ val aggregateExpressions = resultExpressions.flatMap { expr =>
+ expr.collect {
+ case agg: AggregateExpression => agg
+ }
+ }.distinct
+
+ val namedGroupingExpressions = groupingExpressions.map {
+ case ne: NamedExpression => ne -> ne
+ // If the expression is not a NamedExpressions, we add an alias.
+ // So, when we generate the result of the operator, the Aggregate Operator
+ // can directly get the Seq of attributes representing the grouping expressions.
+ case other =>
+ val withAlias = Alias(other, other.toString)()
+ other -> withAlias
+ }
+ val groupExpressionMap = namedGroupingExpressions.toMap
+
+ // The original `resultExpressions` are a set of expressions which may reference
+ // aggregate expressions, grouping column values, and constants. When aggregate operator
+ // emits output rows, we will use `resultExpressions` to generate an output projection
+ // which takes the grouping columns and final aggregate result buffer as input.
+ // Thus, we must re-write the result expressions so that their attributes match up with
+ // the attributes of the final result projection's input row:
+ val rewrittenResultExpressions = resultExpressions.map { expr =>
+ expr.transformDown {
+ case ae: AggregateExpression =>
+ // The final aggregation buffer's attributes will be `finalAggregationAttributes`,
+ // so replace each aggregate expression by its corresponding attribute in the set:
+ ae.resultAttribute
+ case expression =>
+ // Since we're using `namedGroupingAttributes` to extract the grouping key
+ // columns, we need to replace grouping key expressions with their corresponding
+ // attributes. We do not rely on the equality check at here since attributes may
+ // differ cosmetically. Instead, we use semanticEquals.
+ groupExpressionMap.collectFirst {
+ case (expr, ne) if expr semanticEquals expression => ne.toAttribute
+ }.getOrElse(expression)
+ }.asInstanceOf[NamedExpression]
+ }
+
+ Some((
+ namedGroupingExpressions.map(_._2),
+ aggregateExpressions,
+ rewrittenResultExpressions,
+ child))
+
+ case _ => None
+ }
+}
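A hedged sketch of how a physical planning strategy might consume the new extractor; splitAggregate is a hypothetical function used only to show the shape of the match.

    import org.apache.spark.sql.catalyst.expressions.NamedExpression
    import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
    import org.apache.spark.sql.catalyst.planning.PhysicalAggregation
    import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan

    def splitAggregate(plan: LogicalPlan): Option[(Seq[NamedExpression], Seq[AggregateExpression], Seq[NamedExpression], LogicalPlan)] =
      plan match {
        case PhysicalAggregation(grouping, aggregates, results, child) =>
          // grouping expressions are all named, aggregates are deduplicated, and results
          // reference each aggregate's resultAttribute rather than the function itself
          Some((grouping, aggregates, results))
        case _ => None
      }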
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
index d31164fe94..d4447ca32d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
@@ -44,25 +44,9 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT
* returns a constraint of the form `isNotNull(a)`
*/
private def constructIsNotNullConstraints(constraints: Set[Expression]): Set[Expression] = {
- var isNotNullConstraints = Set.empty[Expression]
-
- // First, we propagate constraints if the condition consists of equality and ranges. For all
- // other cases, we return an empty set of constraints
- constraints.foreach {
- case EqualTo(l, r) =>
- isNotNullConstraints ++= Set(IsNotNull(l), IsNotNull(r))
- case GreaterThan(l, r) =>
- isNotNullConstraints ++= Set(IsNotNull(l), IsNotNull(r))
- case GreaterThanOrEqual(l, r) =>
- isNotNullConstraints ++= Set(IsNotNull(l), IsNotNull(r))
- case LessThan(l, r) =>
- isNotNullConstraints ++= Set(IsNotNull(l), IsNotNull(r))
- case LessThanOrEqual(l, r) =>
- isNotNullConstraints ++= Set(IsNotNull(l), IsNotNull(r))
- case Not(EqualTo(l, r)) =>
- isNotNullConstraints ++= Set(IsNotNull(l), IsNotNull(r))
- case _ => // No inference
- }
+ // First, we propagate constraints from the null intolerant expressions.
+ var isNotNullConstraints: Set[Expression] =
+ constraints.flatMap(scanNullIntolerantExpr).map(IsNotNull(_))
// Second, we infer additional constraints from non-nullable attributes that are part of the
// operator's output
@@ -73,6 +57,17 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT
}
/**
+ * Recursively explores the expressions which are null intolerant and returns all attributes
+ * in these expressions.
+ */
+ private def scanNullIntolerantExpr(expr: Expression): Seq[Attribute] = expr match {
+ case a: Attribute => Seq(a)
+ case _: NullIntolerant | IsNotNull(_: NullIntolerant) =>
+ expr.children.flatMap(scanNullIntolerantExpr)
+ case _ => Seq.empty[Attribute]
+ }
+
+ /**
* Infers an additional set of constraints from a given set of equality constraints.
* For e.g., if an operator has constraints of the form (`a = 5`, `a = b`), this returns an
* additional constraint of the form `b = 5`
@@ -127,8 +122,8 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT
AttributeSet(children.flatMap(_.asInstanceOf[QueryPlan[PlanType]].output))
/**
- * The set of all attributes that are produced by this node.
- */
+ * The set of all attributes that are produced by this node.
+ */
def producedAttributes: AttributeSet = AttributeSet.empty
/**
@@ -216,8 +211,10 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT
if (changed) makeCopy(newArgs).asInstanceOf[this.type] else this
}
- /** Returns the result of running [[transformExpressions]] on this node
- * and all its children. */
+ /**
+ * Returns the result of running [[transformExpressions]] on this node
+ * and all its children.
+ */
def transformAllExpressions(rule: PartialFunction[Expression, Expression]): this.type = {
transform {
case q: QueryPlan[_] => q.transformExpressions(rule).asInstanceOf[PlanType]
@@ -315,18 +312,17 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT
/** Args that have cleaned such that differences in expression id should not affect equality */
protected lazy val cleanArgs: Seq[Any] = {
def cleanArg(arg: Any): Any = arg match {
+ // Children are checked using sameResult above.
+ case tn: TreeNode[_] if containsChild(tn) => null
case e: Expression => cleanExpression(e).canonicalized
case other => other
}
productIterator.map {
- // Children are checked using sameResult above.
- case tn: TreeNode[_] if containsChild(tn) => null
- case e: Expression => cleanArg(e)
case s: Option[_] => s.map(cleanArg)
case s: Seq[_] => s.map(cleanArg)
case m: Map[_, _] => m.mapValues(cleanArg)
- case other => other
+ case other => cleanArg(other)
}.toSeq
}
}
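Restating the new null-intolerance scan outside QueryPlan for clarity (scanNullIntolerant is a stand-in name; the method above is private): any attribute reachable only through null-intolerant expressions yields an IsNotNull constraint.

    import org.apache.spark.sql.catalyst.expressions._

    def scanNullIntolerant(expr: Expression): Seq[Attribute] = expr match {
      case a: Attribute => Seq(a)
      case _: NullIntolerant | IsNotNull(_: NullIntolerant) =>
        expr.children.flatMap(scanNullIntolerant)
      case _ => Seq.empty
    }
    // For a constraint such as GreaterThan(a, b) this returns Seq(a, b); mapping IsNotNull over
    // the result reproduces what the removed per-operator cases spelled out one by one.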
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala
index 9ca4f13dd7..13f57c54a5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala
@@ -26,13 +26,15 @@ object JoinType {
case "leftouter" | "left" => LeftOuter
case "rightouter" | "right" => RightOuter
case "leftsemi" => LeftSemi
+ case "leftanti" => LeftAnti
case _ =>
val supported = Seq(
"inner",
"outer", "full", "fullouter",
"leftouter", "left",
"rightouter", "right",
- "leftsemi")
+ "leftsemi",
+ "leftanti")
throw new IllegalArgumentException(s"Unsupported join type '$typ'. " +
"Supported join types include: " + supported.mkString("'", "', '", "'") + ".")
@@ -63,6 +65,10 @@ case object LeftSemi extends JoinType {
override def sql: String = "LEFT SEMI"
}
+case object LeftAnti extends JoinType {
+ override def sql: String = "LEFT ANTI"
+}
+
case class NaturalJoin(tpe: JoinType) extends JoinType {
require(Seq(Inner, LeftOuter, RightOuter, FullOuter).contains(tpe),
"Unsupported natural join type " + tpe)
@@ -70,7 +76,14 @@ case class NaturalJoin(tpe: JoinType) extends JoinType {
}
case class UsingJoin(tpe: JoinType, usingColumns: Seq[UnresolvedAttribute]) extends JoinType {
- require(Seq(Inner, LeftOuter, LeftSemi, RightOuter, FullOuter).contains(tpe),
+ require(Seq(Inner, LeftOuter, LeftSemi, RightOuter, FullOuter, LeftAnti).contains(tpe),
"Unsupported using join type " + tpe)
override def sql: String = "USING " + tpe.sql
}
+
+object LeftExistence {
+ def unapply(joinType: JoinType): Option[JoinType] = joinType match {
+ case LeftSemi | LeftAnti => Some(joinType)
+ case _ => None
+ }
+}
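A small sketch of the new join types in use (illustrative; keepsOnlyLeftOutput is a hypothetical helper): LeftExistence groups the two join types whose output is just the left side's output.

    import org.apache.spark.sql.catalyst.plans.{Inner, JoinType, LeftAnti, LeftExistence, LeftSemi}

    def keepsOnlyLeftOutput(joinType: JoinType): Boolean = joinType match {
      case LeftExistence(_) => true   // LeftSemi or LeftAnti
      case _                => false
    }

    JoinType("leftanti")            // => LeftAnti
    keepsOnlyLeftOutput(LeftAnti)   // => true
    keepsOnlyLeftOutput(Inner)      // => false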
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
index ecf4285c46..aceeb8aadc 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
@@ -79,13 +79,13 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging {
/**
* Computes [[Statistics]] for this plan. The default implementation assumes the output
- * cardinality is the product of of all child plan's cardinality, i.e. applies in the case
+ * cardinality is the product of all child plan's cardinality, i.e. applies in the case
* of cartesian joins.
*
* [[LeafNode]]s must override this.
*/
def statistics: Statistics = {
- if (children.size == 0) {
+ if (children.isEmpty) {
throw new UnsupportedOperationException(s"LeafNode $nodeName must implement statistics.")
}
Statistics(sizeInBytes = children.map(_.statistics.sizeInBytes).product)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index 09c200fa83..d4fc9e4da9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -236,10 +236,24 @@ case class Union(children: Seq[LogicalPlan]) extends LogicalPlan {
})
}
+ private def merge(a: Set[Expression], b: Set[Expression]): Set[Expression] = {
+ val common = a.intersect(b)
+ // A constraint that references only a single attribute can easily be inferred as a predicate,
+ // so group the remaining constraints by the attribute they reference and combine the
+ // constraints that share a reference.
+ val othera = a.diff(common).filter(_.references.size == 1).groupBy(_.references.head)
+ val otherb = b.diff(common).filter(_.references.size == 1).groupBy(_.references.head)
+ // Loosen the constraints: A1 && B1 || A2 && B2 -> (A1 || A2) && (B1 || B2)
+ val others = (othera.keySet intersect otherb.keySet).map { attr =>
+ Or(othera(attr).reduceLeft(And), otherb(attr).reduceLeft(And))
+ }
+ common ++ others
+ }
+
override protected def validConstraints: Set[Expression] = {
children
.map(child => rewriteConstraints(children.head.output, child.output, child.constraints))
- .reduce(_ intersect _)
+ .reduce(merge(_, _))
}
}
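An illustration of the loosened Union constraint merge above, using hand-built expressions (the values are made up for the example; merge itself is private, so the comments describe what validConstraints would now contain).

    import org.apache.spark.sql.catalyst.expressions._
    import org.apache.spark.sql.types.IntegerType

    val a = AttributeReference("a", IntegerType)()
    val b = AttributeReference("b", IntegerType)()
    val leftConstraints  = Set[Expression](GreaterThan(a, Literal(1)), EqualTo(b, Literal(2)))
    val rightConstraints = Set[Expression](GreaterThan(a, Literal(5)), EqualTo(b, Literal(2)))
    // The merged constraints keep the common EqualTo(b, 2) and add the per-attribute
    // disjunction Or(GreaterThan(a, 1), GreaterThan(a, 5)); single-attribute constraints
    // unique to one side are no longer dropped outright, multi-attribute ones still are.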
@@ -252,7 +266,7 @@ case class Join(
override def output: Seq[Attribute] = {
joinType match {
- case LeftSemi =>
+ case LeftExistence(_) =>
left.output
case LeftOuter =>
left.output ++ right.output.map(_.withNullability(true))
@@ -276,7 +290,7 @@ case class Join(
.union(splitConjunctivePredicates(condition.get).toSet)
case Inner =>
left.constraints.union(right.constraints)
- case LeftSemi =>
+ case LeftExistence(_) =>
left.constraints
case LeftOuter =>
left.constraints
@@ -519,7 +533,6 @@ case class Expand(
projections: Seq[Seq[Expression]],
output: Seq[Attribute],
child: LogicalPlan) extends UnaryNode {
-
override def references: AttributeSet =
AttributeSet(projections.flatten.flatMap(_.references))
@@ -527,6 +540,10 @@ case class Expand(
val sizeInBytes = super.statistics.sizeInBytes * projections.length
Statistics(sizeInBytes = sizeInBytes)
}
+
+ // This operator can reuse attributes (for example, making them null when doing a rollup), so
+ // the constraints of the child may no longer be valid.
+ override protected def validConstraints: Set[Expression] = Set.empty[Expression]
}
/**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
index da7f81c785..6df46189b6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
@@ -18,9 +18,45 @@
package org.apache.spark.sql.catalyst.plans.logical
import org.apache.spark.sql.Encoder
+import org.apache.spark.sql.catalyst.analysis.UnresolvedDeserializer
import org.apache.spark.sql.catalyst.encoders._
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.types.{ObjectType, StructType}
+import org.apache.spark.sql.types.{DataType, ObjectType, StructType}
+
+object CatalystSerde {
+ def deserialize[T : Encoder](child: LogicalPlan): DeserializeToObject = {
+ val deserializer = UnresolvedDeserializer(encoderFor[T].deserializer)
+ DeserializeToObject(Alias(deserializer, "obj")(), child)
+ }
+
+ def serialize[T : Encoder](child: LogicalPlan): SerializeFromObject = {
+ SerializeFromObject(encoderFor[T].namedExpressions, child)
+ }
+}
+
+/**
+ * Takes the input row from the child and turns it into an object using the given deserializer expression.
+ * The output of this operator is a single-field safe row containing the deserialized object.
+ */
+case class DeserializeToObject(
+ deserializer: Alias,
+ child: LogicalPlan) extends UnaryNode {
+ override def output: Seq[Attribute] = deserializer.toAttribute :: Nil
+
+ def outputObjectType: DataType = deserializer.dataType
+}
+
+/**
+ * Takes the input object from the child and turns it into an unsafe row using the given serializer
+ * expression. The output of its child must be a single-field row containing the input object.
+ */
+case class SerializeFromObject(
+ serializer: Seq[NamedExpression],
+ child: LogicalPlan) extends UnaryNode {
+ override def output: Seq[Attribute] = serializer.map(_.toAttribute)
+
+ def inputObjectType: DataType = child.output.head.dataType
+}
/**
* A trait for logical operators that apply user defined functions to domain objects.
@@ -33,13 +69,6 @@ trait ObjectOperator extends LogicalPlan {
override def output: Seq[Attribute] = serializer.map(_.toAttribute)
/**
- * An [[ObjectOperator]] may have one or more deserializers to convert internal rows to objects.
- * It must also provide the attributes that are available during the resolution of each
- * deserializer.
- */
- def deserializers: Seq[(Expression, Seq[Attribute])]
-
- /**
* The object type that is produced by the user defined function. Note that the return type here
* is the same whether or not the operator is output serialized data.
*/
@@ -71,7 +100,7 @@ object MapPartitions {
child: LogicalPlan): MapPartitions = {
MapPartitions(
func.asInstanceOf[Iterator[Any] => Iterator[Any]],
- encoderFor[T].fromRowExpression,
+ UnresolvedDeserializer(encoderFor[T].deserializer),
encoderFor[U].namedExpressions,
child)
}
@@ -87,10 +116,32 @@ case class MapPartitions(
func: Iterator[Any] => Iterator[Any],
deserializer: Expression,
serializer: Seq[NamedExpression],
- child: LogicalPlan) extends UnaryNode with ObjectOperator {
- override def deserializers: Seq[(Expression, Seq[Attribute])] = Seq(deserializer -> child.output)
+ child: LogicalPlan) extends UnaryNode with ObjectOperator
+
+object MapElements {
+ def apply[T : Encoder, U : Encoder](
+ func: AnyRef,
+ child: LogicalPlan): MapElements = {
+ MapElements(
+ func,
+ UnresolvedDeserializer(encoderFor[T].deserializer),
+ encoderFor[U].namedExpressions,
+ child)
+ }
}
+/**
+ * A relation produced by applying `func` to each element of the `child`.
+ *
+ * @param deserializer used to extract the input to `func` from an input row.
+ * @param serializer used to serialize the output of `func`.
+ */
+case class MapElements(
+ func: AnyRef,
+ deserializer: Expression,
+ serializer: Seq[NamedExpression],
+ child: LogicalPlan) extends UnaryNode with ObjectOperator
+
/** Factory for constructing new `AppendColumn` nodes. */
object AppendColumns {
def apply[T : Encoder, U : Encoder](
@@ -98,7 +149,7 @@ object AppendColumns {
child: LogicalPlan): AppendColumns = {
new AppendColumns(
func.asInstanceOf[Any => Any],
- encoderFor[T].fromRowExpression,
+ UnresolvedDeserializer(encoderFor[T].deserializer),
encoderFor[U].namedExpressions,
child)
}
@@ -120,8 +171,6 @@ case class AppendColumns(
override def output: Seq[Attribute] = child.output ++ newColumns
def newColumns: Seq[Attribute] = serializer.map(_.toAttribute)
-
- override def deserializers: Seq[(Expression, Seq[Attribute])] = Seq(deserializer -> child.output)
}
/** Factory for constructing new `MapGroups` nodes. */
@@ -133,8 +182,8 @@ object MapGroups {
child: LogicalPlan): MapGroups = {
new MapGroups(
func.asInstanceOf[(Any, Iterator[Any]) => TraversableOnce[Any]],
- encoderFor[K].fromRowExpression,
- encoderFor[T].fromRowExpression,
+ UnresolvedDeserializer(encoderFor[K].deserializer, groupingAttributes),
+ UnresolvedDeserializer(encoderFor[T].deserializer, dataAttributes),
encoderFor[U].namedExpressions,
groupingAttributes,
dataAttributes,
@@ -158,11 +207,7 @@ case class MapGroups(
serializer: Seq[NamedExpression],
groupingAttributes: Seq[Attribute],
dataAttributes: Seq[Attribute],
- child: LogicalPlan) extends UnaryNode with ObjectOperator {
-
- override def deserializers: Seq[(Expression, Seq[Attribute])] =
- Seq(keyDeserializer -> groupingAttributes, valueDeserializer -> dataAttributes)
-}
+ child: LogicalPlan) extends UnaryNode with ObjectOperator
/** Factory for constructing new `CoGroup` nodes. */
object CoGroup {
@@ -170,22 +215,24 @@ object CoGroup {
func: (Key, Iterator[Left], Iterator[Right]) => TraversableOnce[Result],
leftGroup: Seq[Attribute],
rightGroup: Seq[Attribute],
- leftData: Seq[Attribute],
- rightData: Seq[Attribute],
+ leftAttr: Seq[Attribute],
+ rightAttr: Seq[Attribute],
left: LogicalPlan,
right: LogicalPlan): CoGroup = {
require(StructType.fromAttributes(leftGroup) == StructType.fromAttributes(rightGroup))
CoGroup(
func.asInstanceOf[(Any, Iterator[Any], Iterator[Any]) => TraversableOnce[Any]],
- encoderFor[Key].fromRowExpression,
- encoderFor[Left].fromRowExpression,
- encoderFor[Right].fromRowExpression,
+ // The `leftGroup` and `rightGroup` are guaranteed to be of the same schema, so it's safe to
+ // resolve the `keyDeserializer` based on either of them; here we pick the left one.
+ UnresolvedDeserializer(encoderFor[Key].deserializer, leftGroup),
+ UnresolvedDeserializer(encoderFor[Left].deserializer, leftAttr),
+ UnresolvedDeserializer(encoderFor[Right].deserializer, rightAttr),
encoderFor[Result].namedExpressions,
leftGroup,
rightGroup,
- leftData,
- rightData,
+ leftAttr,
+ rightAttr,
left,
right)
}
@@ -206,10 +253,4 @@ case class CoGroup(
leftAttr: Seq[Attribute],
rightAttr: Seq[Attribute],
left: LogicalPlan,
- right: LogicalPlan) extends BinaryNode with ObjectOperator {
-
- override def deserializers: Seq[(Expression, Seq[Attribute])] =
- // The `leftGroup` and `rightGroup` are guaranteed te be of same schema, so it's safe to resolve
- // the `keyDeserializer` based on either of them, here we pick the left one.
- Seq(keyDeserializer -> leftGroup, leftDeserializer -> leftAttr, rightDeserializer -> rightAttr)
-}
+ right: LogicalPlan) extends BinaryNode with ObjectOperator
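A hypothetical helper (not in the patch) showing how the new serde nodes are meant to bracket object-level work: deserialize rows into a single object column, run an object-space transformation, then serialize back to rows.

    import org.apache.spark.sql.Encoder
    import org.apache.spark.sql.catalyst.plans.logical.{CatalystSerde, LogicalPlan, SerializeFromObject}

    def withObjectSerde[T : Encoder](child: LogicalPlan)(
        body: LogicalPlan => LogicalPlan): SerializeFromObject = {
      val deserialized = CatalystSerde.deserialize[T](child)  // rows -> single "obj" column
      CatalystSerde.serialize[T](body(deserialized))           // objects -> rows again
    }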
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
index be9f1ffa22..d449088498 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
@@ -76,9 +76,9 @@ case class OrderedDistribution(ordering: Seq[SortOrder]) extends Distribution {
}
/**
- * Represents data where tuples are broadcasted to every node. It is quite common that the
- * entire set of tuples is transformed into different data structure.
- */
+ * Represents data where tuples are broadcasted to every node. It is quite common that the
+ * entire set of tuples is transformed into different data structure.
+ */
case class BroadcastDistribution(mode: BroadcastMode) extends Distribution
/**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
index 6b7997e903..232ca43588 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
@@ -22,6 +22,7 @@ import java.util.UUID
import scala.collection.Map
import scala.collection.mutable.Stack
+import org.apache.commons.lang.ClassUtils
import org.json4s.JsonAST._
import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods._
@@ -365,20 +366,32 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
* @param newArgs the new product arguments.
*/
def makeCopy(newArgs: Array[AnyRef]): BaseType = attachTree(this, "makeCopy") {
+ // Skip no-arg constructors that are just there for kryo.
val ctors = getClass.getConstructors.filter(_.getParameterTypes.size != 0)
if (ctors.isEmpty) {
sys.error(s"No valid constructor for $nodeName")
}
- val defaultCtor = ctors.maxBy(_.getParameterTypes.size)
+ val allArgs: Array[AnyRef] = if (otherCopyArgs.isEmpty) {
+ newArgs
+ } else {
+ newArgs ++ otherCopyArgs
+ }
+ val defaultCtor = ctors.find { ctor =>
+ if (ctor.getParameterTypes.length != allArgs.length) {
+ false
+ } else if (allArgs.contains(null)) {
+ // if there is a `null`, we can't figure out the class, so we should just fall back
+ // to the older heuristic
+ false
+ } else {
+ val argsArray: Array[Class[_]] = allArgs.map(_.getClass)
+ ClassUtils.isAssignable(argsArray, ctor.getParameterTypes, true /* autoboxing */)
+ }
+ }.getOrElse(ctors.maxBy(_.getParameterTypes.length)) // fall back to older heuristic
try {
CurrentOrigin.withOrigin(origin) {
- // Skip no-arg constructors that are just there for kryo.
- if (otherCopyArgs.isEmpty) {
- defaultCtor.newInstance(newArgs: _*).asInstanceOf[BaseType]
- } else {
- defaultCtor.newInstance((newArgs ++ otherCopyArgs).toArray: _*).asInstanceOf[BaseType]
- }
+ defaultCtor.newInstance(allArgs.toArray: _*).asInstanceOf[BaseType]
}
} catch {
case e: java.lang.IllegalArgumentException =>
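The constructor matching that makeCopy now performs boils down to the commons-lang assignability check; a hedged standalone illustration follows (the classes chosen here are arbitrary).

    import org.apache.commons.lang.ClassUtils

    val argClasses: Array[Class[_]] = Array(classOf[String], classOf[java.lang.Integer])
    val paramTypes: Array[Class[_]] = Array(classOf[Object], classOf[Int])
    ClassUtils.isAssignable(argClasses, paramTypes, true)  // => true, thanks to auto(un)boxing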
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringKeyHashMap.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringKeyHashMap.scala
index 191d5e6399..d5d151a580 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringKeyHashMap.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringKeyHashMap.scala
@@ -41,4 +41,6 @@ class StringKeyHashMap[T](normalizer: (String) => String) {
def remove(key: String): Option[T] = base.remove(normalizer(key))
def iterator: Iterator[(String, T)] = base.toIterator
+
+ def clear(): Unit = base.clear()
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringUtils.scala
index c2eeb3c565..cde8bd5b96 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringUtils.scala
@@ -17,7 +17,7 @@
package org.apache.spark.sql.catalyst.util
-import java.util.regex.Pattern
+import java.util.regex.{Pattern, PatternSyntaxException}
import org.apache.spark.unsafe.types.UTF8String
@@ -52,4 +52,25 @@ object StringUtils {
def isTrueString(s: UTF8String): Boolean = trueStrings.contains(s.toLowerCase)
def isFalseString(s: UTF8String): Boolean = falseStrings.contains(s.toLowerCase)
+
+ /**
+ * Filters a list of names by a pattern, as used in the LIKE clause of "SHOW TABLES / FUNCTIONS" DDL.
+ * @param names the list of names to be filtered
+ * @param pattern the filter pattern; only '*' and '|' are allowed as wildcards, everything else
+ * follows regular-expression conventions; matching is case insensitive and whitespace
+ * on both ends of the pattern is ignored
+ * @return the filtered list of names, in sorted order
+ */
+ def filterPattern(names: Seq[String], pattern: String): Seq[String] = {
+ val funcNames = scala.collection.mutable.SortedSet.empty[String]
+ pattern.trim().split("\\|").foreach { subPattern =>
+ try {
+ val regex = ("(?i)" + subPattern.replaceAll("\\*", ".*")).r
+ funcNames ++= names.filter{ name => regex.pattern.matcher(name).matches() }
+ } catch {
+ case _: PatternSyntaxException =>
+ }
+ }
+ funcNames.toSeq
+ }
}
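A couple of hedged examples of the new filterPattern helper (illustrative values): '*' is a wildcard, '|' separates alternatives, matching is case insensitive, and the result comes back sorted.

    import org.apache.spark.sql.catalyst.util.StringUtils.filterPattern

    val names = Seq("max", "min", "month", "JSON_tuple")
    filterPattern(names, "m*")          // => Seq("max", "min", "month")
    filterPattern(names, "max|json_*")  // => Seq("JSON_tuple", "max")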
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala
index b11365b297..f879b34358 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala
@@ -155,10 +155,13 @@ package object util {
/**
* Returns the string representation of this expression that is safe to be put in
- * code comments of generated code.
+ * code comments of generated code. The length is capped at 128 characters.
*/
- def toCommentSafeString(str: String): String =
- str.replace("*/", "\\*\\/").replace("\\u", "\\\\u")
+ def toCommentSafeString(str: String): String = {
+ val len = math.min(str.length, 128)
+ val suffix = if (str.length > len) "..." else ""
+ str.substring(0, len).replace("*/", "\\*\\/").replace("\\u", "\\\\u") + suffix
+ }
/* FIX ME
implicit class debugLogging(a: Any) {
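
A quick sketch of the new length cap, illustrative and not part of the patch; toCommentSafeString is the member of the catalyst util package object shown above:

    import org.apache.spark.sql.catalyst.util._

    // "*/" and "\u" are escaped so the string cannot break out of a generated /* ... */ comment.
    toCommentSafeString("input[0] */ anything")   // => "input[0] \*\/ anything"
    // Strings longer than 128 characters are truncated and suffixed with "...".
    toCommentSafeString("x" * 200).length         // => 131 (128 kept characters + "...")
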
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala
index dabf9a2fc0..fb7251d71b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala
@@ -23,7 +23,6 @@ import org.json4s.JsonDSL._
import org.apache.spark.annotation.DeveloperApi
/**
- * ::DeveloperApi::
* The data type for User Defined Types (UDTs).
*
* This interface allows a user to make their own classes more interoperable with SparkSQL;
@@ -35,8 +34,11 @@ import org.apache.spark.annotation.DeveloperApi
*
* The conversion via `serialize` occurs when instantiating a `DataFrame` from another RDD.
* The conversion via `deserialize` occurs when reading from a `DataFrame`.
+ *
+ * Note: This was previously a developer API in Spark 1.x. We are making this private in Spark 2.0
+ * because we will very likely create a new version of this that works better with Datasets.
*/
-@DeveloperApi
+private[spark]
abstract class UserDefinedType[UserType >: Null] extends DataType with Serializable {
/** Underlying storage type for this UDT */
diff --git a/sql/catalyst/src/test/resources/log4j.properties b/sql/catalyst/src/test/resources/log4j.properties
index eb3b1999eb..3706a6e361 100644
--- a/sql/catalyst/src/test/resources/log4j.properties
+++ b/sql/catalyst/src/test/resources/log4j.properties
@@ -24,5 +24,4 @@ log4j.appender.file.layout=org.apache.log4j.PatternLayout
log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n
# Ignore messages below warning level from Jetty, because it's a bit verbose
-log4j.logger.org.spark-project.jetty=WARN
-org.spark-project.jetty.LEVEL=WARN
+log4j.logger.org.spark_project.jetty=WARN
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala
index 8207d64798..711e870711 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala
@@ -196,12 +196,11 @@ object RandomDataGenerator {
case ShortType => randomNumeric[Short](
rand, _.nextInt().toShort, Seq(Short.MinValue, Short.MaxValue, 0.toShort))
case NullType => Some(() => null)
- case ArrayType(elementType, containsNull) => {
+ case ArrayType(elementType, containsNull) =>
forType(elementType, nullable = containsNull, rand).map {
elementGenerator => () => Seq.fill(rand.nextInt(MAX_ARR_SIZE))(elementGenerator())
}
- }
- case MapType(keyType, valueType, valueContainsNull) => {
+ case MapType(keyType, valueType, valueContainsNull) =>
for (
keyGenerator <- forType(keyType, nullable = false, rand);
valueGenerator <-
@@ -221,8 +220,7 @@ object RandomDataGenerator {
keys.zip(values).toMap
}
}
- }
- case StructType(fields) => {
+ case StructType(fields) =>
val maybeFieldGenerators: Seq[Option[() => Any]] = fields.map { field =>
forType(field.dataType, nullable = field.nullable, rand)
}
@@ -232,8 +230,7 @@ object RandomDataGenerator {
} else {
None
}
- }
- case udt: UserDefinedType[_] => {
+ case udt: UserDefinedType[_] =>
val maybeSqlTypeGenerator = forType(udt.sqlType, nullable, rand)
// Because random data generator at here returns scala value, we need to
// convert it to catalyst value to call udt's deserialize.
@@ -253,7 +250,6 @@ object RandomDataGenerator {
} else {
None
}
- }
case unsupportedType => None
}
// Handle nullability by wrapping the non-null value generator:
@@ -277,7 +273,7 @@ object RandomDataGenerator {
val fields = mutable.ArrayBuffer.empty[Any]
schema.fields.foreach { f =>
f.dataType match {
- case ArrayType(childType, nullable) => {
+ case ArrayType(childType, nullable) =>
val data = if (f.nullable && rand.nextFloat() <= PROBABILITY_OF_NULL) {
null
} else {
@@ -294,10 +290,8 @@ object RandomDataGenerator {
arr
}
fields += data
- }
- case StructType(children) => {
+ case StructType(children) =>
fields += randomRow(rand, StructType(children))
- }
case _ =>
val generator = RandomDataGenerator.forType(f.dataType, f.nullable, rand)
assert(generator.isDefined, "Unsupported type")
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/RowTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/RowTest.scala
index d9577dea1b..c9c9599e7f 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/RowTest.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/RowTest.scala
@@ -121,7 +121,7 @@ class RowTest extends FunSpec with Matchers {
externalRow should be theSameInstanceAs externalRow.copy()
}
- it("copy should return same ref for interal rows") {
+ it("copy should return same ref for internal rows") {
internalRow should be theSameInstanceAs internalRow.copy()
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
index dd31050bb5..5ca5a72512 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
@@ -248,10 +248,10 @@ class ScalaReflectionSuite extends SparkFunSuite {
Seq(
("mirror", () => mirror),
("dataTypeFor", () => dataTypeFor[ComplexData]),
- ("constructorFor", () => constructorFor[ComplexData]),
+ ("constructorFor", () => deserializerFor[ComplexData]),
("extractorsFor", {
val inputObject = BoundReference(0, dataTypeForComplexData, nullable = false)
- () => extractorsFor[ComplexData](inputObject)
+ () => serializerFor[ComplexData](inputObject)
}),
("getConstructorParameters(cls)", () => getConstructorParameters(classOf[ComplexData])),
("getConstructorParameterNames", () => getConstructorParameterNames(classOf[ComplexData])),
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
index a90dfc5039..ad101d1c40 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
@@ -272,6 +272,62 @@ class AnalysisErrorSuite extends AnalysisTest {
testRelation2.where('bad_column > 1).groupBy('a)(UnresolvedAlias(max('b))),
"cannot resolve '`bad_column`'" :: Nil)
+ errorTest(
+ "slide duration greater than window in time window",
+ testRelation2.select(
+ TimeWindow(Literal("2016-01-01 01:01:01"), "1 second", "2 second", "0 second").as("window")),
+ s"The slide duration " :: " must be less than or equal to the windowDuration " :: Nil
+ )
+
+ errorTest(
+ "start time greater than slide duration in time window",
+ testRelation.select(
+ TimeWindow(Literal("2016-01-01 01:01:01"), "1 second", "1 second", "1 minute").as("window")),
+ "The start time " :: " must be less than the slideDuration " :: Nil
+ )
+
+ errorTest(
+ "start time equal to slide duration in time window",
+ testRelation.select(
+ TimeWindow(Literal("2016-01-01 01:01:01"), "1 second", "1 second", "1 second").as("window")),
+ "The start time " :: " must be less than the slideDuration " :: Nil
+ )
+
+ errorTest(
+ "negative window duration in time window",
+ testRelation.select(
+ TimeWindow(Literal("2016-01-01 01:01:01"), "-1 second", "1 second", "0 second").as("window")),
+ "The window duration " :: " must be greater than 0." :: Nil
+ )
+
+ errorTest(
+ "zero window duration in time window",
+ testRelation.select(
+ TimeWindow(Literal("2016-01-01 01:01:01"), "0 second", "1 second", "0 second").as("window")),
+ "The window duration " :: " must be greater than 0." :: Nil
+ )
+
+ errorTest(
+ "negative slide duration in time window",
+ testRelation.select(
+ TimeWindow(Literal("2016-01-01 01:01:01"), "1 second", "-1 second", "0 second").as("window")),
+ "The slide duration " :: " must be greater than 0." :: Nil
+ )
+
+ errorTest(
+ "zero slide duration in time window",
+ testRelation.select(
+ TimeWindow(Literal("2016-01-01 01:01:01"), "1 second", "0 second", "0 second").as("window")),
+ "The slide duration" :: " must be greater than 0." :: Nil
+ )
+
+ errorTest(
+ "negative start time in time window",
+ testRelation.select(
+ TimeWindow(Literal("2016-01-01 01:01:01"), "1 second", "1 second", "-5 second").as("window")),
+ "The start time" :: "must be greater than or equal to 0." :: Nil
+ )
+
test("SPARK-6452 regression test") {
// CheckAnalysis should throw AnalysisException when Aggregate contains missing attribute(s)
// Since we manually construct the logical plan at here and Sum only accept
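
For readability, the eight new error tests above amount to the following constraints on TimeWindow; the snippet restates them as code and is illustrative only:

    // windowDuration > 0, slideDuration > 0, slideDuration <= windowDuration,
    // 0 <= startTime < slideDuration.
    TimeWindow(Literal("2016-01-01 01:01:01"), "1 second", "2 second", "0 second")
    // -> analysis error: "The slide duration ... must be less than or equal to the windowDuration ..."
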
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala
index 6fa4beed99..b1fcf011f4 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala
@@ -30,9 +30,9 @@ trait AnalysisTest extends PlanTest {
private def makeAnalyzer(caseSensitive: Boolean): Analyzer = {
val conf = new SimpleCatalystConf(caseSensitive)
- val catalog = new SessionCatalog(new InMemoryCatalog, conf)
- catalog.createTempTable("TaBlE", TestRelations.testRelation, ignoreIfExists = true)
- new Analyzer(catalog, EmptyFunctionRegistry, conf) {
+ val catalog = new SessionCatalog(new InMemoryCatalog, EmptyFunctionRegistry, conf)
+ catalog.createTempTable("TaBlE", TestRelations.testRelation, overrideIfExists = true)
+ new Analyzer(catalog, conf) {
override val extendedResolutionRules = EliminateSubqueryAliases :: Nil
}
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala
index 31501864a8..b3b1f5b920 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala
@@ -32,8 +32,8 @@ import org.apache.spark.sql.types._
class DecimalPrecisionSuite extends PlanTest with BeforeAndAfter {
private val conf = new SimpleCatalystConf(caseSensitiveAnalysis = true)
- private val catalog = new SessionCatalog(new InMemoryCatalog, conf)
- private val analyzer = new Analyzer(catalog, EmptyFunctionRegistry, conf)
+ private val catalog = new SessionCatalog(new InMemoryCatalog, EmptyFunctionRegistry, conf)
+ private val analyzer = new Analyzer(catalog, conf)
private val relation = LocalRelation(
AttributeReference("i", IntegerType)(),
@@ -52,7 +52,7 @@ class DecimalPrecisionSuite extends PlanTest with BeforeAndAfter {
private val b: Expression = UnresolvedAttribute("b")
before {
- catalog.createTempTable("table", relation, ignoreIfExists = true)
+ catalog.createTempTable("table", relation, overrideIfExists = true)
}
private def checkType(expression: Expression, expectedType: DataType): Unit = {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/CatalogTestCases.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/CatalogTestCases.scala
index 277c2d717e..f961fe3292 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/CatalogTestCases.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/CatalogTestCases.scala
@@ -149,6 +149,15 @@ abstract class CatalogTestCases extends SparkFunSuite with BeforeAndAfterEach {
// Tables
// --------------------------------------------------------------------------
+ test("the table type of an external table should be EXTERNAL_TABLE") {
+ val catalog = newBasicCatalog()
+ val table =
+ newTable("external_table1", "db2").copy(tableType = CatalogTableType.EXTERNAL_TABLE)
+ catalog.createTable("db2", table, ignoreIfExists = false)
+ val actual = catalog.getTable("db2", "external_table1")
+ assert(actual.tableType === CatalogTableType.EXTERNAL_TABLE)
+ }
+
test("drop table") {
val catalog = newBasicCatalog()
assert(catalog.listTables("db2").toSet == Set("tbl1", "tbl2"))
@@ -210,7 +219,7 @@ abstract class CatalogTestCases extends SparkFunSuite with BeforeAndAfterEach {
}
test("get table") {
- assert(newBasicCatalog().getTable("db2", "tbl1").name.table == "tbl1")
+ assert(newBasicCatalog().getTable("db2", "tbl1").identifier.table == "tbl1")
}
test("get table when database/table does not exist") {
@@ -272,31 +281,37 @@ abstract class CatalogTestCases extends SparkFunSuite with BeforeAndAfterEach {
test("drop partitions") {
val catalog = newBasicCatalog()
assert(catalogPartitionsEqual(catalog, "db2", "tbl2", Seq(part1, part2)))
- catalog.dropPartitions("db2", "tbl2", Seq(part1.spec), ignoreIfNotExists = false)
+ catalog.dropPartitions(
+ "db2", "tbl2", Seq(part1.spec), ignoreIfNotExists = false)
assert(catalogPartitionsEqual(catalog, "db2", "tbl2", Seq(part2)))
resetState()
val catalog2 = newBasicCatalog()
assert(catalogPartitionsEqual(catalog2, "db2", "tbl2", Seq(part1, part2)))
- catalog2.dropPartitions("db2", "tbl2", Seq(part1.spec, part2.spec), ignoreIfNotExists = false)
+ catalog2.dropPartitions(
+ "db2", "tbl2", Seq(part1.spec, part2.spec), ignoreIfNotExists = false)
assert(catalog2.listPartitions("db2", "tbl2").isEmpty)
}
test("drop partitions when database/table does not exist") {
val catalog = newBasicCatalog()
intercept[AnalysisException] {
- catalog.dropPartitions("does_not_exist", "tbl1", Seq(), ignoreIfNotExists = false)
+ catalog.dropPartitions(
+ "does_not_exist", "tbl1", Seq(), ignoreIfNotExists = false)
}
intercept[AnalysisException] {
- catalog.dropPartitions("db2", "does_not_exist", Seq(), ignoreIfNotExists = false)
+ catalog.dropPartitions(
+ "db2", "does_not_exist", Seq(), ignoreIfNotExists = false)
}
}
test("drop partitions that do not exist") {
val catalog = newBasicCatalog()
intercept[AnalysisException] {
- catalog.dropPartitions("db2", "tbl2", Seq(part3.spec), ignoreIfNotExists = false)
+ catalog.dropPartitions(
+ "db2", "tbl2", Seq(part3.spec), ignoreIfNotExists = false)
}
- catalog.dropPartitions("db2", "tbl2", Seq(part3.spec), ignoreIfNotExists = true)
+ catalog.dropPartitions(
+ "db2", "tbl2", Seq(part3.spec), ignoreIfNotExists = true)
}
test("get partition") {
@@ -433,7 +448,8 @@ abstract class CatalogTestCases extends SparkFunSuite with BeforeAndAfterEach {
test("get function") {
val catalog = newBasicCatalog()
assert(catalog.getFunction("db2", "func1") ==
- CatalogFunction(FunctionIdentifier("func1", Some("db2")), funcClass))
+ CatalogFunction(FunctionIdentifier("func1", Some("db2")), funcClass,
+ Seq.empty[(String, String)]))
intercept[AnalysisException] {
catalog.getFunction("db2", "does_not_exist")
}
@@ -452,7 +468,7 @@ abstract class CatalogTestCases extends SparkFunSuite with BeforeAndAfterEach {
assert(catalog.getFunction("db2", "func1").className == funcClass)
catalog.renameFunction("db2", "func1", newName)
intercept[AnalysisException] { catalog.getFunction("db2", "func1") }
- assert(catalog.getFunction("db2", newName).name.funcName == newName)
+ assert(catalog.getFunction("db2", newName).identifier.funcName == newName)
assert(catalog.getFunction("db2", newName).className == funcClass)
intercept[AnalysisException] { catalog.renameFunction("db2", "does_not_exist", "me") }
}
@@ -464,21 +480,6 @@ abstract class CatalogTestCases extends SparkFunSuite with BeforeAndAfterEach {
}
}
- test("alter function") {
- val catalog = newBasicCatalog()
- assert(catalog.getFunction("db2", "func1").className == funcClass)
- catalog.alterFunction("db2", newFunc("func1").copy(className = "muhaha"))
- assert(catalog.getFunction("db2", "func1").className == "muhaha")
- intercept[AnalysisException] { catalog.alterFunction("db2", newFunc("funcky")) }
- }
-
- test("alter function when database does not exist") {
- val catalog = newBasicCatalog()
- intercept[AnalysisException] {
- catalog.alterFunction("does_not_exist", newFunc())
- }
- }
-
test("list functions") {
val catalog = newBasicCatalog()
catalog.createFunction("db2", newFunc("func2"))
@@ -549,15 +550,19 @@ abstract class CatalogTestUtils {
def newTable(name: String, database: Option[String] = None): CatalogTable = {
CatalogTable(
- name = TableIdentifier(name, database),
+ identifier = TableIdentifier(name, database),
tableType = CatalogTableType.EXTERNAL_TABLE,
storage = storageFormat,
- schema = Seq(CatalogColumn("col1", "int"), CatalogColumn("col2", "string")),
- partitionColumns = Seq(CatalogColumn("a", "int"), CatalogColumn("b", "string")))
+ schema = Seq(
+ CatalogColumn("col1", "int"),
+ CatalogColumn("col2", "string"),
+ CatalogColumn("a", "int"),
+ CatalogColumn("b", "string")),
+ partitionColumnNames = Seq("a", "b"))
}
def newFunc(name: String, database: Option[String] = None): CatalogFunction = {
- CatalogFunction(FunctionIdentifier(name, database), funcClass)
+ CatalogFunction(FunctionIdentifier(name, database), funcClass, Seq.empty[(String, String)])
}
/**
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala
index 74e995cc5b..426273e1e3 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.catalog
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier}
+import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionInfo, Literal}
import org.apache.spark.sql.catalyst.plans.logical.{Range, SubqueryAlias}
@@ -61,7 +62,7 @@ class SessionCatalogSuite extends SparkFunSuite {
test("get database when a database exists") {
val catalog = new SessionCatalog(newBasicCatalog())
- val db1 = catalog.getDatabase("db1")
+ val db1 = catalog.getDatabaseMetadata("db1")
assert(db1.name == "db1")
assert(db1.description.contains("db1"))
}
@@ -69,7 +70,7 @@ class SessionCatalogSuite extends SparkFunSuite {
test("get database should throw exception when the database does not exist") {
val catalog = new SessionCatalog(newBasicCatalog())
intercept[AnalysisException] {
- catalog.getDatabase("db_that_does_not_exist")
+ catalog.getDatabaseMetadata("db_that_does_not_exist")
}
}
@@ -127,10 +128,10 @@ class SessionCatalogSuite extends SparkFunSuite {
test("alter database") {
val catalog = new SessionCatalog(newBasicCatalog())
- val db1 = catalog.getDatabase("db1")
+ val db1 = catalog.getDatabaseMetadata("db1")
// Note: alter properties here because Hive does not support altering other fields
catalog.alterDatabase(db1.copy(properties = Map("k" -> "v3", "good" -> "true")))
- val newDb1 = catalog.getDatabase("db1")
+ val newDb1 = catalog.getDatabaseMetadata("db1")
assert(db1.properties.isEmpty)
assert(newDb1.properties.size == 2)
assert(newDb1.properties.get("k") == Some("v3"))
@@ -197,17 +198,17 @@ class SessionCatalogSuite extends SparkFunSuite {
val catalog = new SessionCatalog(newBasicCatalog())
val tempTable1 = Range(1, 10, 1, 10, Seq())
val tempTable2 = Range(1, 20, 2, 10, Seq())
- catalog.createTempTable("tbl1", tempTable1, ignoreIfExists = false)
- catalog.createTempTable("tbl2", tempTable2, ignoreIfExists = false)
+ catalog.createTempTable("tbl1", tempTable1, overrideIfExists = false)
+ catalog.createTempTable("tbl2", tempTable2, overrideIfExists = false)
assert(catalog.getTempTable("tbl1") == Some(tempTable1))
assert(catalog.getTempTable("tbl2") == Some(tempTable2))
assert(catalog.getTempTable("tbl3") == None)
// Temporary table already exists
intercept[AnalysisException] {
- catalog.createTempTable("tbl1", tempTable1, ignoreIfExists = false)
+ catalog.createTempTable("tbl1", tempTable1, overrideIfExists = false)
}
// Temporary table already exists but we override it
- catalog.createTempTable("tbl1", tempTable2, ignoreIfExists = true)
+ catalog.createTempTable("tbl1", tempTable2, overrideIfExists = true)
assert(catalog.getTempTable("tbl1") == Some(tempTable2))
}
@@ -232,10 +233,9 @@ class SessionCatalogSuite extends SparkFunSuite {
intercept[AnalysisException] {
catalog.dropTable(TableIdentifier("tbl1", Some("unknown_db")), ignoreIfNotExists = true)
}
- // Table does not exist
- intercept[AnalysisException] {
- catalog.dropTable(TableIdentifier("unknown_table", Some("db2")), ignoreIfNotExists = false)
- }
+ // If the table does not exist, no exception is thrown. Instead, an error message is logged
+ // to the console when ignoreIfNotExists is set to false.
+ catalog.dropTable(TableIdentifier("unknown_table", Some("db2")), ignoreIfNotExists = false)
catalog.dropTable(TableIdentifier("unknown_table", Some("db2")), ignoreIfNotExists = true)
}
@@ -243,7 +243,7 @@ class SessionCatalogSuite extends SparkFunSuite {
val externalCatalog = newBasicCatalog()
val sessionCatalog = new SessionCatalog(externalCatalog)
val tempTable = Range(1, 10, 2, 10, Seq())
- sessionCatalog.createTempTable("tbl1", tempTable, ignoreIfExists = false)
+ sessionCatalog.createTempTable("tbl1", tempTable, overrideIfExists = false)
sessionCatalog.setCurrentDatabase("db2")
assert(sessionCatalog.getTempTable("tbl1") == Some(tempTable))
assert(externalCatalog.listTables("db2").toSet == Set("tbl1", "tbl2"))
@@ -255,7 +255,7 @@ class SessionCatalogSuite extends SparkFunSuite {
sessionCatalog.dropTable(TableIdentifier("tbl1"), ignoreIfNotExists = false)
assert(externalCatalog.listTables("db2").toSet == Set("tbl2"))
// If database is specified, temp tables are never dropped
- sessionCatalog.createTempTable("tbl1", tempTable, ignoreIfExists = false)
+ sessionCatalog.createTempTable("tbl1", tempTable, overrideIfExists = false)
sessionCatalog.createTable(newTable("tbl1", "db2"), ignoreIfExists = false)
sessionCatalog.dropTable(TableIdentifier("tbl1", Some("db2")), ignoreIfNotExists = false)
assert(sessionCatalog.getTempTable("tbl1") == Some(tempTable))
@@ -299,7 +299,7 @@ class SessionCatalogSuite extends SparkFunSuite {
val externalCatalog = newBasicCatalog()
val sessionCatalog = new SessionCatalog(externalCatalog)
val tempTable = Range(1, 10, 2, 10, Seq())
- sessionCatalog.createTempTable("tbl1", tempTable, ignoreIfExists = false)
+ sessionCatalog.createTempTable("tbl1", tempTable, overrideIfExists = false)
sessionCatalog.setCurrentDatabase("db2")
assert(sessionCatalog.getTempTable("tbl1") == Some(tempTable))
assert(externalCatalog.listTables("db2").toSet == Set("tbl1", "tbl2"))
@@ -327,7 +327,7 @@ class SessionCatalogSuite extends SparkFunSuite {
assert(newTbl1.properties.get("toh") == Some("frem"))
// Alter table without explicitly specifying database
sessionCatalog.setCurrentDatabase("db2")
- sessionCatalog.alterTable(tbl1.copy(name = TableIdentifier("tbl1")))
+ sessionCatalog.alterTable(tbl1.copy(identifier = TableIdentifier("tbl1")))
val newestTbl1 = externalCatalog.getTable("db2", "tbl1")
assert(newestTbl1 == tbl1)
}
@@ -345,21 +345,21 @@ class SessionCatalogSuite extends SparkFunSuite {
test("get table") {
val externalCatalog = newBasicCatalog()
val sessionCatalog = new SessionCatalog(externalCatalog)
- assert(sessionCatalog.getTable(TableIdentifier("tbl1", Some("db2")))
+ assert(sessionCatalog.getTableMetadata(TableIdentifier("tbl1", Some("db2")))
== externalCatalog.getTable("db2", "tbl1"))
// Get table without explicitly specifying database
sessionCatalog.setCurrentDatabase("db2")
- assert(sessionCatalog.getTable(TableIdentifier("tbl1"))
+ assert(sessionCatalog.getTableMetadata(TableIdentifier("tbl1"))
== externalCatalog.getTable("db2", "tbl1"))
}
test("get table when database/table does not exist") {
val catalog = new SessionCatalog(newBasicCatalog())
intercept[AnalysisException] {
- catalog.getTable(TableIdentifier("tbl1", Some("unknown_db")))
+ catalog.getTableMetadata(TableIdentifier("tbl1", Some("unknown_db")))
}
intercept[AnalysisException] {
- catalog.getTable(TableIdentifier("unknown_table", Some("db2")))
+ catalog.getTableMetadata(TableIdentifier("unknown_table", Some("db2")))
}
}
@@ -368,7 +368,7 @@ class SessionCatalogSuite extends SparkFunSuite {
val sessionCatalog = new SessionCatalog(externalCatalog)
val tempTable1 = Range(1, 10, 1, 10, Seq())
val metastoreTable1 = externalCatalog.getTable("db2", "tbl1")
- sessionCatalog.createTempTable("tbl1", tempTable1, ignoreIfExists = false)
+ sessionCatalog.createTempTable("tbl1", tempTable1, overrideIfExists = false)
sessionCatalog.setCurrentDatabase("db2")
// If we explicitly specify the database, we'll look up the relation in that database
assert(sessionCatalog.lookupRelation(TableIdentifier("tbl1", Some("db2")))
@@ -385,7 +385,7 @@ class SessionCatalogSuite extends SparkFunSuite {
test("lookup table relation with alias") {
val catalog = new SessionCatalog(newBasicCatalog())
val alias = "monster"
- val tableMetadata = catalog.getTable(TableIdentifier("tbl1", Some("db2")))
+ val tableMetadata = catalog.getTableMetadata(TableIdentifier("tbl1", Some("db2")))
val relation = SubqueryAlias("tbl1", CatalogRelation("db2", tableMetadata))
val relationWithAlias =
SubqueryAlias(alias,
@@ -406,7 +406,7 @@ class SessionCatalogSuite extends SparkFunSuite {
assert(!catalog.tableExists(TableIdentifier("tbl2", Some("db1"))))
// If database is explicitly specified, do not check temporary tables
val tempTable = Range(1, 10, 1, 10, Seq())
- catalog.createTempTable("tbl3", tempTable, ignoreIfExists = false)
+ catalog.createTempTable("tbl3", tempTable, overrideIfExists = false)
assert(!catalog.tableExists(TableIdentifier("tbl3", Some("db2"))))
// If database is not explicitly specified, check the current database
catalog.setCurrentDatabase("db2")
@@ -418,8 +418,8 @@ class SessionCatalogSuite extends SparkFunSuite {
test("list tables without pattern") {
val catalog = new SessionCatalog(newBasicCatalog())
val tempTable = Range(1, 10, 2, 10, Seq())
- catalog.createTempTable("tbl1", tempTable, ignoreIfExists = false)
- catalog.createTempTable("tbl4", tempTable, ignoreIfExists = false)
+ catalog.createTempTable("tbl1", tempTable, overrideIfExists = false)
+ catalog.createTempTable("tbl4", tempTable, overrideIfExists = false)
assert(catalog.listTables("db1").toSet ==
Set(TableIdentifier("tbl1"), TableIdentifier("tbl4")))
assert(catalog.listTables("db2").toSet ==
@@ -435,8 +435,8 @@ class SessionCatalogSuite extends SparkFunSuite {
test("list tables with pattern") {
val catalog = new SessionCatalog(newBasicCatalog())
val tempTable = Range(1, 10, 2, 10, Seq())
- catalog.createTempTable("tbl1", tempTable, ignoreIfExists = false)
- catalog.createTempTable("tbl4", tempTable, ignoreIfExists = false)
+ catalog.createTempTable("tbl1", tempTable, overrideIfExists = false)
+ catalog.createTempTable("tbl4", tempTable, overrideIfExists = false)
assert(catalog.listTables("db1", "*").toSet == catalog.listTables("db1").toSet)
assert(catalog.listTables("db2", "*").toSet == catalog.listTables("db2").toSet)
assert(catalog.listTables("db2", "tbl*").toSet ==
@@ -496,19 +496,25 @@ class SessionCatalogSuite extends SparkFunSuite {
val sessionCatalog = new SessionCatalog(externalCatalog)
assert(catalogPartitionsEqual(externalCatalog, "db2", "tbl2", Seq(part1, part2)))
sessionCatalog.dropPartitions(
- TableIdentifier("tbl2", Some("db2")), Seq(part1.spec), ignoreIfNotExists = false)
+ TableIdentifier("tbl2", Some("db2")),
+ Seq(part1.spec),
+ ignoreIfNotExists = false)
assert(catalogPartitionsEqual(externalCatalog, "db2", "tbl2", Seq(part2)))
// Drop partitions without explicitly specifying database
sessionCatalog.setCurrentDatabase("db2")
sessionCatalog.dropPartitions(
- TableIdentifier("tbl2"), Seq(part2.spec), ignoreIfNotExists = false)
+ TableIdentifier("tbl2"),
+ Seq(part2.spec),
+ ignoreIfNotExists = false)
assert(externalCatalog.listPartitions("db2", "tbl2").isEmpty)
// Drop multiple partitions at once
sessionCatalog.createPartitions(
TableIdentifier("tbl2", Some("db2")), Seq(part1, part2), ignoreIfExists = false)
assert(catalogPartitionsEqual(externalCatalog, "db2", "tbl2", Seq(part1, part2)))
sessionCatalog.dropPartitions(
- TableIdentifier("tbl2", Some("db2")), Seq(part1.spec, part2.spec), ignoreIfNotExists = false)
+ TableIdentifier("tbl2", Some("db2")),
+ Seq(part1.spec, part2.spec),
+ ignoreIfNotExists = false)
assert(externalCatalog.listPartitions("db2", "tbl2").isEmpty)
}
@@ -516,11 +522,15 @@ class SessionCatalogSuite extends SparkFunSuite {
val catalog = new SessionCatalog(newBasicCatalog())
intercept[AnalysisException] {
catalog.dropPartitions(
- TableIdentifier("tbl1", Some("does_not_exist")), Seq(), ignoreIfNotExists = false)
+ TableIdentifier("tbl1", Some("does_not_exist")),
+ Seq(),
+ ignoreIfNotExists = false)
}
intercept[AnalysisException] {
catalog.dropPartitions(
- TableIdentifier("does_not_exist", Some("db2")), Seq(), ignoreIfNotExists = false)
+ TableIdentifier("does_not_exist", Some("db2")),
+ Seq(),
+ ignoreIfNotExists = false)
}
}
@@ -528,10 +538,14 @@ class SessionCatalogSuite extends SparkFunSuite {
val catalog = new SessionCatalog(newBasicCatalog())
intercept[AnalysisException] {
catalog.dropPartitions(
- TableIdentifier("tbl2", Some("db2")), Seq(part3.spec), ignoreIfNotExists = false)
+ TableIdentifier("tbl2", Some("db2")),
+ Seq(part3.spec),
+ ignoreIfNotExists = false)
}
catalog.dropPartitions(
- TableIdentifier("tbl2", Some("db2")), Seq(part3.spec), ignoreIfNotExists = true)
+ TableIdentifier("tbl2", Some("db2")),
+ Seq(part3.spec),
+ ignoreIfNotExists = true)
}
test("get partition") {
@@ -658,78 +672,94 @@ class SessionCatalogSuite extends SparkFunSuite {
val externalCatalog = newEmptyCatalog()
val sessionCatalog = new SessionCatalog(externalCatalog)
sessionCatalog.createDatabase(newDb("mydb"), ignoreIfExists = false)
- sessionCatalog.createFunction(newFunc("myfunc", Some("mydb")))
+ sessionCatalog.createFunction(newFunc("myfunc", Some("mydb")), ignoreIfExists = false)
assert(externalCatalog.listFunctions("mydb", "*").toSet == Set("myfunc"))
// Create function without explicitly specifying database
sessionCatalog.setCurrentDatabase("mydb")
- sessionCatalog.createFunction(newFunc("myfunc2"))
+ sessionCatalog.createFunction(newFunc("myfunc2"), ignoreIfExists = false)
assert(externalCatalog.listFunctions("mydb", "*").toSet == Set("myfunc", "myfunc2"))
}
test("create function when database does not exist") {
val catalog = new SessionCatalog(newBasicCatalog())
intercept[AnalysisException] {
- catalog.createFunction(newFunc("func5", Some("does_not_exist")))
+ catalog.createFunction(
+ newFunc("func5", Some("does_not_exist")), ignoreIfExists = false)
}
}
test("create function that already exists") {
val catalog = new SessionCatalog(newBasicCatalog())
intercept[AnalysisException] {
- catalog.createFunction(newFunc("func1", Some("db2")))
+ catalog.createFunction(newFunc("func1", Some("db2")), ignoreIfExists = false)
}
+ catalog.createFunction(newFunc("func1", Some("db2")), ignoreIfExists = true)
}
test("create temp function") {
val catalog = new SessionCatalog(newBasicCatalog())
- val tempFunc1 = newFunc("temp1")
- val tempFunc2 = newFunc("temp2")
- catalog.createTempFunction(tempFunc1, ignoreIfExists = false)
- catalog.createTempFunction(tempFunc2, ignoreIfExists = false)
- assert(catalog.getTempFunction("temp1") == Some(tempFunc1))
- assert(catalog.getTempFunction("temp2") == Some(tempFunc2))
- assert(catalog.getTempFunction("temp3") == None)
+ val tempFunc1 = (e: Seq[Expression]) => e.head
+ val tempFunc2 = (e: Seq[Expression]) => e.last
+ val info1 = new ExpressionInfo("tempFunc1", "temp1")
+ val info2 = new ExpressionInfo("tempFunc2", "temp2")
+ catalog.createTempFunction("temp1", info1, tempFunc1, ignoreIfExists = false)
+ catalog.createTempFunction("temp2", info2, tempFunc2, ignoreIfExists = false)
+ val arguments = Seq(Literal(1), Literal(2), Literal(3))
+ assert(catalog.lookupFunction("temp1", arguments) === Literal(1))
+ assert(catalog.lookupFunction("temp2", arguments) === Literal(3))
+ // Temporary function does not exist.
+ intercept[AnalysisException] {
+ catalog.lookupFunction("temp3", arguments)
+ }
+ val tempFunc3 = (e: Seq[Expression]) => Literal(e.size)
+ val info3 = new ExpressionInfo("tempFunc3", "temp1")
// Temporary function already exists
intercept[AnalysisException] {
- catalog.createTempFunction(tempFunc1, ignoreIfExists = false)
+ catalog.createTempFunction("temp1", info3, tempFunc3, ignoreIfExists = false)
}
// Temporary function is overridden
- val tempFunc3 = tempFunc1.copy(className = "something else")
- catalog.createTempFunction(tempFunc3, ignoreIfExists = true)
- assert(catalog.getTempFunction("temp1") == Some(tempFunc3))
+ catalog.createTempFunction("temp1", info3, tempFunc3, ignoreIfExists = true)
+ assert(catalog.lookupFunction("temp1", arguments) === Literal(arguments.length))
}
test("drop function") {
val externalCatalog = newBasicCatalog()
val sessionCatalog = new SessionCatalog(externalCatalog)
assert(externalCatalog.listFunctions("db2", "*").toSet == Set("func1"))
- sessionCatalog.dropFunction(FunctionIdentifier("func1", Some("db2")))
+ sessionCatalog.dropFunction(
+ FunctionIdentifier("func1", Some("db2")), ignoreIfNotExists = false)
assert(externalCatalog.listFunctions("db2", "*").isEmpty)
// Drop function without explicitly specifying database
sessionCatalog.setCurrentDatabase("db2")
- sessionCatalog.createFunction(newFunc("func2", Some("db2")))
+ sessionCatalog.createFunction(newFunc("func2", Some("db2")), ignoreIfExists = false)
assert(externalCatalog.listFunctions("db2", "*").toSet == Set("func2"))
- sessionCatalog.dropFunction(FunctionIdentifier("func2"))
+ sessionCatalog.dropFunction(FunctionIdentifier("func2"), ignoreIfNotExists = false)
assert(externalCatalog.listFunctions("db2", "*").isEmpty)
}
test("drop function when database/function does not exist") {
val catalog = new SessionCatalog(newBasicCatalog())
intercept[AnalysisException] {
- catalog.dropFunction(FunctionIdentifier("something", Some("does_not_exist")))
+ catalog.dropFunction(
+ FunctionIdentifier("something", Some("does_not_exist")), ignoreIfNotExists = false)
}
intercept[AnalysisException] {
- catalog.dropFunction(FunctionIdentifier("does_not_exist"))
+ catalog.dropFunction(FunctionIdentifier("does_not_exist"), ignoreIfNotExists = false)
}
+ catalog.dropFunction(FunctionIdentifier("does_not_exist"), ignoreIfNotExists = true)
}
test("drop temp function") {
val catalog = new SessionCatalog(newBasicCatalog())
- val tempFunc = newFunc("func1")
- catalog.createTempFunction(tempFunc, ignoreIfExists = false)
- assert(catalog.getTempFunction("func1") == Some(tempFunc))
+ val info = new ExpressionInfo("tempFunc", "func1")
+ val tempFunc = (e: Seq[Expression]) => e.head
+ catalog.createTempFunction("func1", info, tempFunc, ignoreIfExists = false)
+ val arguments = Seq(Literal(1), Literal(2), Literal(3))
+ assert(catalog.lookupFunction("func1", arguments) === Literal(1))
catalog.dropTempFunction("func1", ignoreIfNotExists = false)
- assert(catalog.getTempFunction("func1") == None)
+ intercept[AnalysisException] {
+ catalog.lookupFunction("func1", arguments)
+ }
intercept[AnalysisException] {
catalog.dropTempFunction("func1", ignoreIfNotExists = false)
}
@@ -738,132 +768,47 @@ class SessionCatalogSuite extends SparkFunSuite {
test("get function") {
val catalog = new SessionCatalog(newBasicCatalog())
- val expected = CatalogFunction(FunctionIdentifier("func1", Some("db2")), funcClass)
- assert(catalog.getFunction(FunctionIdentifier("func1", Some("db2"))) == expected)
+ val expected =
+ CatalogFunction(FunctionIdentifier("func1", Some("db2")), funcClass,
+ Seq.empty[(String, String)])
+ assert(catalog.getFunctionMetadata(FunctionIdentifier("func1", Some("db2"))) == expected)
// Get function without explicitly specifying database
catalog.setCurrentDatabase("db2")
- assert(catalog.getFunction(FunctionIdentifier("func1")) == expected)
+ assert(catalog.getFunctionMetadata(FunctionIdentifier("func1")) == expected)
}
test("get function when database/function does not exist") {
val catalog = new SessionCatalog(newBasicCatalog())
intercept[AnalysisException] {
- catalog.getFunction(FunctionIdentifier("func1", Some("does_not_exist")))
- }
- intercept[AnalysisException] {
- catalog.getFunction(FunctionIdentifier("does_not_exist", Some("db2")))
- }
- }
-
- test("get temp function") {
- val externalCatalog = newBasicCatalog()
- val sessionCatalog = new SessionCatalog(externalCatalog)
- val metastoreFunc = externalCatalog.getFunction("db2", "func1")
- val tempFunc = newFunc("func1").copy(className = "something weird")
- sessionCatalog.createTempFunction(tempFunc, ignoreIfExists = false)
- sessionCatalog.setCurrentDatabase("db2")
- // If a database is specified, we'll always return the function in that database
- assert(sessionCatalog.getFunction(FunctionIdentifier("func1", Some("db2"))) == metastoreFunc)
- // If no database is specified, we'll first return temporary functions
- assert(sessionCatalog.getFunction(FunctionIdentifier("func1")) == tempFunc)
- // Then, if no such temporary function exist, check the current database
- sessionCatalog.dropTempFunction("func1", ignoreIfNotExists = false)
- assert(sessionCatalog.getFunction(FunctionIdentifier("func1")) == metastoreFunc)
- }
-
- test("rename function") {
- val externalCatalog = newBasicCatalog()
- val sessionCatalog = new SessionCatalog(externalCatalog)
- val newName = "funcky"
- assert(sessionCatalog.getFunction(
- FunctionIdentifier("func1", Some("db2"))) == newFunc("func1", Some("db2")))
- assert(externalCatalog.listFunctions("db2", "*").toSet == Set("func1"))
- sessionCatalog.renameFunction(
- FunctionIdentifier("func1", Some("db2")), FunctionIdentifier(newName, Some("db2")))
- assert(sessionCatalog.getFunction(
- FunctionIdentifier(newName, Some("db2"))) == newFunc(newName, Some("db2")))
- assert(externalCatalog.listFunctions("db2", "*").toSet == Set(newName))
- // Rename function without explicitly specifying database
- sessionCatalog.setCurrentDatabase("db2")
- sessionCatalog.renameFunction(FunctionIdentifier(newName), FunctionIdentifier("func1"))
- assert(sessionCatalog.getFunction(
- FunctionIdentifier("func1")) == newFunc("func1", Some("db2")))
- assert(externalCatalog.listFunctions("db2", "*").toSet == Set("func1"))
- // Renaming "db2.func1" to "db1.func2" should fail because databases don't match
- intercept[AnalysisException] {
- sessionCatalog.renameFunction(
- FunctionIdentifier("func1", Some("db2")), FunctionIdentifier("func2", Some("db1")))
- }
- }
-
- test("rename function when database/function does not exist") {
- val catalog = new SessionCatalog(newBasicCatalog())
- intercept[AnalysisException] {
- catalog.renameFunction(
- FunctionIdentifier("func1", Some("does_not_exist")),
- FunctionIdentifier("func5", Some("does_not_exist")))
+ catalog.getFunctionMetadata(FunctionIdentifier("func1", Some("does_not_exist")))
}
intercept[AnalysisException] {
- catalog.renameFunction(
- FunctionIdentifier("does_not_exist", Some("db2")),
- FunctionIdentifier("x", Some("db2")))
+ catalog.getFunctionMetadata(FunctionIdentifier("does_not_exist", Some("db2")))
}
}
- test("rename temp function") {
- val externalCatalog = newBasicCatalog()
- val sessionCatalog = new SessionCatalog(externalCatalog)
- val tempFunc = newFunc("func1").copy(className = "something weird")
- sessionCatalog.createTempFunction(tempFunc, ignoreIfExists = false)
- sessionCatalog.setCurrentDatabase("db2")
- // If a database is specified, we'll always rename the function in that database
- sessionCatalog.renameFunction(
- FunctionIdentifier("func1", Some("db2")), FunctionIdentifier("func3", Some("db2")))
- assert(sessionCatalog.getTempFunction("func1") == Some(tempFunc))
- assert(sessionCatalog.getTempFunction("func3") == None)
- assert(externalCatalog.listFunctions("db2", "*").toSet == Set("func3"))
- // If no database is specified, we'll first rename temporary functions
- sessionCatalog.createFunction(newFunc("func1", Some("db2")))
- sessionCatalog.renameFunction(FunctionIdentifier("func1"), FunctionIdentifier("func4"))
- assert(sessionCatalog.getTempFunction("func4") ==
- Some(tempFunc.copy(name = FunctionIdentifier("func4"))))
- assert(sessionCatalog.getTempFunction("func1") == None)
- assert(externalCatalog.listFunctions("db2", "*").toSet == Set("func1", "func3"))
- // Then, if no such temporary function exist, rename the function in the current database
- sessionCatalog.renameFunction(FunctionIdentifier("func1"), FunctionIdentifier("func5"))
- assert(sessionCatalog.getTempFunction("func5") == None)
- assert(externalCatalog.listFunctions("db2", "*").toSet == Set("func3", "func5"))
- }
-
- test("alter function") {
- val catalog = new SessionCatalog(newBasicCatalog())
- assert(catalog.getFunction(FunctionIdentifier("func1", Some("db2"))).className == funcClass)
- catalog.alterFunction(newFunc("func1", Some("db2")).copy(className = "muhaha"))
- assert(catalog.getFunction(FunctionIdentifier("func1", Some("db2"))).className == "muhaha")
- // Alter function without explicitly specifying database
- catalog.setCurrentDatabase("db2")
- catalog.alterFunction(newFunc("func1").copy(className = "derpy"))
- assert(catalog.getFunction(FunctionIdentifier("func1")).className == "derpy")
- }
-
- test("alter function when database/function does not exist") {
+ test("lookup temp function") {
val catalog = new SessionCatalog(newBasicCatalog())
+ val info1 = new ExpressionInfo("tempFunc1", "func1")
+ val tempFunc1 = (e: Seq[Expression]) => e.head
+ catalog.createTempFunction("func1", info1, tempFunc1, ignoreIfExists = false)
+ assert(catalog.lookupFunction("func1", Seq(Literal(1), Literal(2), Literal(3))) == Literal(1))
+ catalog.dropTempFunction("func1", ignoreIfNotExists = false)
intercept[AnalysisException] {
- catalog.alterFunction(newFunc("func5", Some("does_not_exist")))
- }
- intercept[AnalysisException] {
- catalog.alterFunction(newFunc("funcky", Some("db2")))
+ catalog.lookupFunction("func1", Seq(Literal(1), Literal(2), Literal(3)))
}
}
test("list functions") {
val catalog = new SessionCatalog(newBasicCatalog())
- val tempFunc1 = newFunc("func1").copy(className = "march")
- val tempFunc2 = newFunc("yes_me").copy(className = "april")
- catalog.createFunction(newFunc("func2", Some("db2")))
- catalog.createFunction(newFunc("not_me", Some("db2")))
- catalog.createTempFunction(tempFunc1, ignoreIfExists = false)
- catalog.createTempFunction(tempFunc2, ignoreIfExists = false)
+ val info1 = new ExpressionInfo("tempFunc1", "func1")
+ val info2 = new ExpressionInfo("tempFunc2", "yes_me")
+ val tempFunc1 = (e: Seq[Expression]) => e.head
+ val tempFunc2 = (e: Seq[Expression]) => e.last
+ catalog.createFunction(newFunc("func2", Some("db2")), ignoreIfExists = false)
+ catalog.createFunction(newFunc("not_me", Some("db2")), ignoreIfExists = false)
+ catalog.createTempFunction("func1", info1, tempFunc1, ignoreIfExists = false)
+ catalog.createTempFunction("yes_me", info2, tempFunc2, ignoreIfExists = false)
assert(catalog.listFunctions("db1", "*").toSet ==
Set(FunctionIdentifier("func1"),
FunctionIdentifier("yes_me")))
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
index f6583bfe42..18752014ea 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
@@ -315,7 +315,7 @@ class ExpressionEncoderSuite extends PlanTest with AnalysisTest {
val attr = AttributeReference("obj", ObjectType(encoder.clsTag.runtimeClass))()
val inputPlan = LocalRelation(attr)
val plan =
- Project(Alias(encoder.fromRowExpression, "obj")() :: Nil,
+ Project(Alias(encoder.deserializer, "obj")() :: Nil,
Project(encoder.namedExpressions,
inputPlan))
assertAnalysisSuccess(plan)
@@ -360,7 +360,7 @@ class ExpressionEncoderSuite extends PlanTest with AnalysisTest {
|${encoder.schema.treeString}
|
|fromRow Expressions:
- |${boundEncoder.fromRowExpression.treeString}
+ |${boundEncoder.deserializer.treeString}
""".stripMargin)
}
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
index 99e3b13ce8..2cf8ca7000 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
@@ -382,6 +382,7 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(InitCap(Literal("a b")), "A B")
checkEvaluation(InitCap(Literal(" a")), " A")
checkEvaluation(InitCap(Literal("the test")), "The Test")
+ checkEvaluation(InitCap(Literal("sParK")), "Spark")
// scalastyle:off
// non ascii characters are not allowed in the code, so we disable the scalastyle here.
checkEvaluation(InitCap(Literal("世界")), "世界")
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/TimeWindowSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/TimeWindowSuite.scala
new file mode 100644
index 0000000000..b82cf8d169
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/TimeWindowSuite.scala
@@ -0,0 +1,111 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions
+
+import org.scalatest.PrivateMethodTester
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.types.LongType
+
+class TimeWindowSuite extends SparkFunSuite with ExpressionEvalHelper with PrivateMethodTester {
+
+ test("time window is unevaluable") {
+ intercept[UnsupportedOperationException] {
+ evaluate(TimeWindow(Literal(10L), "1 second", "1 second", "0 second"))
+ }
+ }
+
+ private def checkErrorMessage(msg: String, value: String): Unit = {
+ val validDuration = "10 second"
+ val validTime = "5 second"
+ val e1 = intercept[IllegalArgumentException] {
+ TimeWindow(Literal(10L), value, validDuration, validTime).windowDuration
+ }
+ val e2 = intercept[IllegalArgumentException] {
+ TimeWindow(Literal(10L), validDuration, value, validTime).slideDuration
+ }
+ val e3 = intercept[IllegalArgumentException] {
+ TimeWindow(Literal(10L), validDuration, validDuration, value).startTime
+ }
+ Seq(e1, e2, e3).foreach { e =>
+ assert(e.getMessage.contains(msg))
+ }
+ }
+
+ test("blank intervals throw exception") {
+ for (blank <- Seq(null, " ", "\n", "\t")) {
+ checkErrorMessage(
+ "The window duration, slide duration and start time cannot be null or blank.", blank)
+ }
+ }
+
+ test("invalid intervals throw exception") {
+ checkErrorMessage(
+ "did not correspond to a valid interval string.", "2 apples")
+ }
+
+ test("intervals greater than a month throws exception") {
+ checkErrorMessage(
+ "Intervals greater than or equal to a month is not supported (1 month).", "1 month")
+ }
+
+ test("interval strings work with and without 'interval' prefix and return microseconds") {
+ val validDuration = "10 second"
+ for ((text, seconds) <- Seq(
+ ("1 second", 1000000), // 1e6
+ ("1 minute", 60000000), // 6e7
+ ("2 hours", 7200000000L))) { // 72e9
+ assert(TimeWindow(Literal(10L), text, validDuration, "0 seconds").windowDuration === seconds)
+ assert(TimeWindow(Literal(10L), "interval " + text, validDuration, "0 seconds").windowDuration
+ === seconds)
+ }
+ }
+
+ private val parseExpression = PrivateMethod[Long]('parseExpression)
+
+ test("parse sql expression for duration in microseconds - string") {
+ val dur = TimeWindow.invokePrivate(parseExpression(Literal("5 seconds")))
+ assert(dur.isInstanceOf[Long])
+ assert(dur === 5000000)
+ }
+
+ test("parse sql expression for duration in microseconds - integer") {
+ val dur = TimeWindow.invokePrivate(parseExpression(Literal(100)))
+ assert(dur.isInstanceOf[Long])
+ assert(dur === 100)
+ }
+
+ test("parse sql expression for duration in microseconds - long") {
+ val dur = TimeWindow.invokePrivate(parseExpression(Literal.create(2 << 52, LongType)))
+ assert(dur.isInstanceOf[Long])
+ assert(dur === (2 << 52))
+ }
+
+ test("parse sql expression for duration in microseconds - invalid interval") {
+ intercept[IllegalArgumentException] {
+ TimeWindow.invokePrivate(parseExpression(Literal("2 apples")))
+ }
+ }
+
+ test("parse sql expression for duration in microseconds - invalid expression") {
+ intercept[AnalysisException] {
+ TimeWindow.invokePrivate(parseExpression(Rand(123)))
+ }
+ }
+}
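
For quick reference, the parsing behaviour pinned down by the new suite, with values copied from the assertions above:

    // Durations are parsed to microseconds, with or without an "interval " prefix.
    TimeWindow(Literal(10L), "1 minute", "10 second", "0 seconds").windowDuration          // 60000000L
    TimeWindow(Literal(10L), "interval 2 hours", "10 second", "0 seconds").windowDuration  // 7200000000L
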
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatterSuite.scala
index 9da1068e9c..f57b82bb96 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatterSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatterSuite.scala
@@ -18,13 +18,20 @@
package org.apache.spark.sql.catalyst.expressions.codegen
import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.util._
class CodeFormatterSuite extends SparkFunSuite {
def testCase(name: String)(input: String)(expected: String): Unit = {
test(name) {
- assert(CodeFormatter.format(input).trim === expected.trim)
+ if (CodeFormatter.format(input).trim !== expected.trim) {
+ fail(
+ s"""
+ |== FAIL: Formatted code doesn't match ===
+ |${sideBySide(CodeFormatter.format(input).trim, expected.trim).mkString("\n")}
+ """.stripMargin)
+ }
}
}
@@ -93,4 +100,50 @@ class CodeFormatterSuite extends SparkFunSuite {
|/* 004 */ c)
""".stripMargin
}
+
+ testCase("single line comments") {
+ """// This is a comment about class A { { { ( (
+ |class A {
+ |class body;
+ |}""".stripMargin
+ }{
+ """
+ |/* 001 */ // This is a comment about class A { { { ( (
+ |/* 002 */ class A {
+ |/* 003 */ class body;
+ |/* 004 */ }
+ """.stripMargin
+ }
+
+ testCase("single line comments /* */ ") {
+ """/** This is a comment about class A { { { ( ( */
+ |class A {
+ |class body;
+ |}""".stripMargin
+ }{
+ """
+ |/* 001 */ /** This is a comment about class A { { { ( ( */
+ |/* 002 */ class A {
+ |/* 003 */ class body;
+ |/* 004 */ }
+ """.stripMargin
+ }
+
+ testCase("multi-line comments") {
+ """ /* This is a comment about
+ |class A {
+ |class body; ...*/
+ |class A {
+ |class body;
+ |}""".stripMargin
+ }{
+ """
+ |/* 001 */ /* This is a comment about
+ |/* 002 */ class A {
+ |/* 003 */ class body; ...*/
+ |/* 004 */ class A {
+ |/* 005 */ class body;
+ |/* 006 */ }
+ """.stripMargin
+ }
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BinaryComparisonSimplificationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BinaryComparisonSimplificationSuite.scala
new file mode 100644
index 0000000000..7cd038570b
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BinaryComparisonSimplificationSuite.scala
@@ -0,0 +1,95 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.optimizer
+
+import org.apache.spark.sql.catalyst.analysis._
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.dsl.plans._
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral}
+import org.apache.spark.sql.catalyst.plans.PlanTest
+import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.rules._
+
+class BinaryComparisonSimplificationSuite extends PlanTest with PredicateHelper {
+
+ object Optimize extends RuleExecutor[LogicalPlan] {
+ val batches =
+ Batch("AnalysisNodes", Once,
+ EliminateSubqueryAliases) ::
+ Batch("Constant Folding", FixedPoint(50),
+ NullPropagation,
+ ConstantFolding,
+ BooleanSimplification,
+ BinaryComparisonSimplification,
+ PruneFilters) :: Nil
+ }
+
+ val nullableRelation = LocalRelation('a.int.withNullability(true))
+ val nonNullableRelation = LocalRelation('a.int.withNullability(false))
+
+ test("Preserve nullable exprs in general") {
+ for (e <- Seq('a === 'a, 'a <= 'a, 'a >= 'a, 'a < 'a, 'a > 'a)) {
+ val plan = nullableRelation.where(e).analyze
+ val actual = Optimize.execute(plan)
+ val correctAnswer = plan
+ comparePlans(actual, correctAnswer)
+ }
+ }
+
+ test("Preserve non-deterministic exprs") {
+ val plan = nonNullableRelation
+ .where(Rand(0) === Rand(0) && Rand(1) <=> Rand(1)).analyze
+ val actual = Optimize.execute(plan)
+ val correctAnswer = plan
+ comparePlans(actual, correctAnswer)
+ }
+
+ test("Nullable Simplification Primitive: <=>") {
+ val plan = nullableRelation.select('a <=> 'a).analyze
+ val actual = Optimize.execute(plan)
+ val correctAnswer = nullableRelation.select(Alias(TrueLiteral, "(a <=> a)")()).analyze
+ comparePlans(actual, correctAnswer)
+ }
+
+ test("Non-Nullable Simplification Primitive") {
+ val plan = nonNullableRelation
+ .select('a === 'a, 'a <=> 'a, 'a <= 'a, 'a >= 'a, 'a < 'a, 'a > 'a).analyze
+ val actual = Optimize.execute(plan)
+ val correctAnswer = nonNullableRelation
+ .select(
+ Alias(TrueLiteral, "(a = a)")(),
+ Alias(TrueLiteral, "(a <=> a)")(),
+ Alias(TrueLiteral, "(a <= a)")(),
+ Alias(TrueLiteral, "(a >= a)")(),
+ Alias(FalseLiteral, "(a < a)")(),
+ Alias(FalseLiteral, "(a > a)")())
+ .analyze
+ comparePlans(actual, correctAnswer)
+ }
+
+ test("Expression Normalization") {
+ val plan = nonNullableRelation.where(
+ 'a * Literal(100) + Pi() === Pi() + Literal(100) * 'a &&
+ DateAdd(CurrentDate(), 'a + Literal(2)) <= DateAdd(CurrentDate(), Literal(2) + 'a))
+ .analyze
+ val actual = Optimize.execute(plan)
+ val correctAnswer = nonNullableRelation.analyze
+ comparePlans(actual, correctAnswer)
+ }
+}
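
In plain terms, the rule covered by this suite rewrites self-comparisons on non-nullable attributes to constants; the sketch below uses the suite's own fixtures (Optimize, nonNullableRelation) and is illustrative only:

    // For a non-nullable attribute 'a:
    //   'a === 'a, 'a <=> 'a, 'a <= 'a, 'a >= 'a   ->  true
    //   'a < 'a,   'a > 'a                         ->  false
    // Nullable attributes keep the original comparison, except 'a <=> 'a, which is always true.
    Optimize.execute(nonNullableRelation.select('a === 'a, 'a < 'a).analyze)
    // selects the literals true and false instead of re-evaluating the comparisons
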
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala
index e2c76b700f..8147d06969 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala
@@ -140,8 +140,7 @@ class BooleanSimplificationSuite extends PlanTest with PredicateHelper {
private val caseInsensitiveConf = new SimpleCatalystConf(false)
private val caseInsensitiveAnalyzer = new Analyzer(
- new SessionCatalog(new InMemoryCatalog, caseInsensitiveConf),
- EmptyFunctionRegistry,
+ new SessionCatalog(new InMemoryCatalog, EmptyFunctionRegistry, caseInsensitiveConf),
caseInsensitiveConf)
test("(a && b) || (a && c) => a && (b || c) when case insensitive") {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala
index 2248e03b2f..52b574c0e6 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala
@@ -34,7 +34,7 @@ class ColumnPruningSuite extends PlanTest {
object Optimize extends RuleExecutor[LogicalPlan] {
val batches = Batch("Column pruning", FixedPoint(100),
- PushPredicateThroughProject,
+ PushDownPredicate,
ColumnPruning,
CollapseProject) :: Nil
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSortsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSortsSuite.scala
index 3824c67563..8c92ad82ac 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSortsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSortsSuite.scala
@@ -29,8 +29,8 @@ import org.apache.spark.sql.catalyst.rules._
class EliminateSortsSuite extends PlanTest {
val conf = new SimpleCatalystConf(caseSensitiveAnalysis = true, orderByOrdinal = false)
- val catalog = new SessionCatalog(new InMemoryCatalog, conf)
- val analyzer = new Analyzer(catalog, EmptyFunctionRegistry, conf)
+ val catalog = new SessionCatalog(new InMemoryCatalog, EmptyFunctionRegistry, conf)
+ val analyzer = new Analyzer(catalog, conf)
object Optimize extends RuleExecutor[LogicalPlan] {
val batches =
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
index b84ae7c5bb..df7529d83f 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
@@ -33,14 +33,12 @@ class FilterPushdownSuite extends PlanTest {
val batches =
Batch("Subqueries", Once,
EliminateSubqueryAliases) ::
- Batch("Filter Pushdown", Once,
+ Batch("Filter Pushdown", FixedPoint(10),
SamplePushDown,
CombineFilters,
- PushPredicateThroughProject,
+ PushDownPredicate,
BooleanSimplification,
PushPredicateThroughJoin,
- PushPredicateThroughGenerate,
- PushPredicateThroughAggregate,
CollapseProject) :: Nil
}
@@ -620,8 +618,8 @@ class FilterPushdownSuite extends PlanTest {
val optimized = Optimize.execute(originalQuery.analyze)
val correctAnswer = testRelation
- .select('a, 'b)
.where('a === 3)
+ .select('a, 'b)
.groupBy('a)('a, count('b) as 'c)
.where('c === 2L)
.analyze
@@ -638,8 +636,8 @@ class FilterPushdownSuite extends PlanTest {
val optimized = Optimize.execute(originalQuery.analyze)
val correctAnswer = testRelation
- .select('a, 'b)
.where('a + 1 < 3)
+ .select('a, 'b)
.groupBy('a)(('a + 1) as 'aa, count('b) as 'c)
.where('c === 2L || 'aa > 4)
.analyze
@@ -656,8 +654,8 @@ class FilterPushdownSuite extends PlanTest {
val optimized = Optimize.execute(originalQuery.analyze)
val correctAnswer = testRelation
- .select('a, 'b)
.where("s" === "s")
+ .select('a, 'b)
.groupBy('a)('a, count('b) as 'c, "s" as 'd)
.where('c === 2L)
.analyze
@@ -681,4 +679,68 @@ class FilterPushdownSuite extends PlanTest {
comparePlans(optimized, correctAnswer)
}
+
+ test("broadcast hint") {
+ val originalQuery = BroadcastHint(testRelation)
+ .where('a === 2L && 'b + Rand(10).as("rnd") === 3)
+
+ val optimized = Optimize.execute(originalQuery.analyze)
+
+ val correctAnswer = BroadcastHint(testRelation.where('a === 2L))
+ .where('b + Rand(10).as("rnd") === 3)
+ .analyze
+
+ comparePlans(optimized, correctAnswer)
+ }
+
+ test("union") {
+ val testRelation2 = LocalRelation('d.int, 'e.int, 'f.int)
+
+ val originalQuery = Union(Seq(testRelation, testRelation2))
+ .where('a === 2L && 'b + Rand(10).as("rnd") === 3)
+
+ val optimized = Optimize.execute(originalQuery.analyze)
+
+ val correctAnswer = Union(Seq(
+ testRelation.where('a === 2L),
+ testRelation2.where('d === 2L)))
+ .where('b + Rand(10).as("rnd") === 3)
+ .analyze
+
+ comparePlans(optimized, correctAnswer)
+ }
+
+ test("intersect") {
+ val testRelation2 = LocalRelation('d.int, 'e.int, 'f.int)
+
+ val originalQuery = Intersect(testRelation, testRelation2)
+ .where('a === 2L && 'b + Rand(10).as("rnd") === 3)
+
+ val optimized = Optimize.execute(originalQuery.analyze)
+
+ val correctAnswer = Intersect(
+ testRelation.where('a === 2L),
+ testRelation2.where('d === 2L))
+ .where('b + Rand(10).as("rnd") === 3)
+ .analyze
+
+ comparePlans(optimized, correctAnswer)
+ }
+
+ test("except") {
+ val testRelation2 = LocalRelation('d.int, 'e.int, 'f.int)
+
+ val originalQuery = Except(testRelation, testRelation2)
+ .where('a === 2L && 'b + Rand(10).as("rnd") === 3)
+
+ val optimized = Optimize.execute(originalQuery.analyze)
+
+ val correctAnswer = Except(
+ testRelation.where('a === 2L),
+ testRelation2)
+ .where('b + Rand(10).as("rnd") === 3)
+ .analyze
+
+ comparePlans(optimized, correctAnswer)
+ }
}
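
All four new set-operation tests follow the same rule of thumb: only the deterministic conjunct ('a === 2L) may be pushed into the children, while the Rand-based conjunct has to stay in a Filter above the set operation, because pushing a non-deterministic predicate could change how often it is evaluated; Except additionally pushes only into its left child. A rough sketch of that split, assuming access to PredicateHelper.splitConjunctivePredicates (the trait mixed into BinaryComparisonSimplificationSuite above); the value names are illustrative:

    // Split the test predicate into the part that may be pushed down and the part that may not.
    val cond = 'a === 2L && 'b + Rand(10).as("rnd") === 3
    val (pushable, stayUp) = splitConjunctivePredicates(cond).partition(_.deterministic)
    // pushable -> Seq('a === 2L): pushed below Union/Intersect (both children) and Except (left child)
    // stayUp   -> the Rand-based comparison: kept in the Filter above the set operation
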
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala
index e2f8146bee..c1ebf8b09e 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala
@@ -36,12 +36,10 @@ class JoinOptimizationSuite extends PlanTest {
EliminateSubqueryAliases) ::
Batch("Filter Pushdown", FixedPoint(100),
CombineFilters,
- PushPredicateThroughProject,
+ PushDownPredicate,
BooleanSimplification,
ReorderJoin,
PushPredicateThroughJoin,
- PushPredicateThroughGenerate,
- PushPredicateThroughAggregate,
ColumnPruning,
CollapseProject) :: Nil
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LikeSimplificationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LikeSimplificationSuite.scala
index 741bc113cf..fdde89d079 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LikeSimplificationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LikeSimplificationSuite.scala
@@ -61,6 +61,20 @@ class LikeSimplificationSuite extends PlanTest {
comparePlans(optimized, correctAnswer)
}
+ test("simplify Like into startsWith and EndsWith") {
+ val originalQuery =
+ testRelation
+ .where(('a like "abc\\%def") || ('a like "abc%def"))
+
+ val optimized = Optimize.execute(originalQuery.analyze)
+ val correctAnswer = testRelation
+ .where(('a like "abc\\%def") ||
+ (Length('a) >= 6 && (StartsWith('a, "abc") && EndsWith('a, "def"))))
+ .analyze
+
+ comparePlans(optimized, correctAnswer)
+ }
+
test("simplify Like into Contains") {
val originalQuery =
testRelation
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizerExtendableSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizerExtendableSuite.scala
index 7e3da6bea7..6e5672ddc3 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizerExtendableSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizerExtendableSuite.scala
@@ -23,21 +23,21 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule
/**
- * This is a test for SPARK-7727 if the Optimizer is kept being extendable
- */
+ * This is a test for SPARK-7727, verifying that the Optimizer stays extendable.
+ */
class OptimizerExtendableSuite extends SparkFunSuite {
/**
- * Dummy rule for test batches
- */
+ * Dummy rule for test batches
+ */
object DummyRule extends Rule[LogicalPlan] {
def apply(p: LogicalPlan): LogicalPlan = p
}
/**
- * This class represents a dummy extended optimizer that takes the batches of the
- * Optimizer and adds custom ones.
- */
+ * This class represents a dummy extended optimizer that takes the batches of the
+ * Optimizer and adds custom ones.
+ */
class ExtendedOptimizer extends Optimizer {
// rules set to DummyRule, would not be executed anyways
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PruneFiltersSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PruneFiltersSuite.scala
index 14fb72a8a3..d8cfec5391 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PruneFiltersSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PruneFiltersSuite.scala
@@ -34,7 +34,7 @@ class PruneFiltersSuite extends PlanTest {
Batch("Filter Pushdown and Pruning", Once,
CombineFilters,
PruneFilters,
- PushPredicateThroughProject,
+ PushDownPredicate,
PushPredicateThroughJoin) :: Nil
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala
index d436b627f6..c02fec3085 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala
@@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLite
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._
-import org.apache.spark.sql.types.IntegerType
+import org.apache.spark.sql.types.{IntegerType, NullType}
class SimplifyConditionalSuite extends PlanTest with PredicateHelper {
@@ -41,6 +41,7 @@ class SimplifyConditionalSuite extends PlanTest with PredicateHelper {
private val trueBranch = (TrueLiteral, Literal(5))
private val normalBranch = (NonFoldableLiteral(true), Literal(10))
private val unreachableBranch = (FalseLiteral, Literal(20))
+ private val nullBranch = (Literal.create(null, NullType), Literal(30))
test("simplify if") {
assertEquivalent(
@@ -50,18 +51,22 @@ class SimplifyConditionalSuite extends PlanTest with PredicateHelper {
assertEquivalent(
If(FalseLiteral, Literal(10), Literal(20)),
Literal(20))
+
+ assertEquivalent(
+ If(Literal.create(null, NullType), Literal(10), Literal(20)),
+ Literal(20))
}
test("remove unreachable branches") {
// i.e. removing branches whose conditions are always false
assertEquivalent(
- CaseWhen(unreachableBranch :: normalBranch :: unreachableBranch :: Nil, None),
+ CaseWhen(unreachableBranch :: normalBranch :: unreachableBranch :: nullBranch :: Nil, None),
CaseWhen(normalBranch :: Nil, None))
}
test("remove entire CaseWhen if only the else branch is reachable") {
assertEquivalent(
- CaseWhen(unreachableBranch :: unreachableBranch :: Nil, Some(Literal(30))),
+ CaseWhen(unreachableBranch :: unreachableBranch :: nullBranch :: Nil, Some(Literal(30))),
Literal(30))
assertEquivalent(
@@ -71,12 +76,13 @@ class SimplifyConditionalSuite extends PlanTest with PredicateHelper {
test("remove entire CaseWhen if the first branch is always true") {
assertEquivalent(
- CaseWhen(trueBranch :: normalBranch :: Nil, None),
+ CaseWhen(trueBranch :: normalBranch :: nullBranch :: Nil, None),
Literal(5))
// Test branch elimination and simplification in combination
assertEquivalent(
- CaseWhen(unreachableBranch :: unreachableBranch:: trueBranch :: normalBranch :: Nil, None),
+ CaseWhen(unreachableBranch :: unreachableBranch :: nullBranch :: trueBranch :: normalBranch
+ :: Nil, None),
Literal(5))
// Make sure this doesn't trigger if there is a non-foldable branch before the true branch
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/TypedFilterOptimizationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/TypedFilterOptimizationSuite.scala
new file mode 100644
index 0000000000..1fae64e3bc
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/TypedFilterOptimizationSuite.scala
@@ -0,0 +1,74 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.optimizer
+
+import scala.reflect.runtime.universe.TypeTag
+
+import org.apache.spark.sql.catalyst.analysis.UnresolvedDeserializer
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.dsl.plans._
+import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder}
+import org.apache.spark.sql.catalyst.plans.PlanTest
+import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
+import org.apache.spark.sql.catalyst.rules.RuleExecutor
+import org.apache.spark.sql.types.BooleanType
+
+class TypedFilterOptimizationSuite extends PlanTest {
+ object Optimize extends RuleExecutor[LogicalPlan] {
+ val batches =
+ Batch("EliminateSerialization", FixedPoint(50),
+ EliminateSerialization) ::
+ Batch("EmbedSerializerInFilter", FixedPoint(50),
+ EmbedSerializerInFilter) :: Nil
+ }
+
+ implicit private def productEncoder[T <: Product : TypeTag] = ExpressionEncoder[T]()
+
+ test("back to back filter") {
+ val input = LocalRelation('_1.int, '_2.int)
+ val f1 = (i: (Int, Int)) => i._1 > 0
+ val f2 = (i: (Int, Int)) => i._2 > 0
+
+ val query = input.filter(f1).filter(f2).analyze
+
+ val optimized = Optimize.execute(query)
+
+ val expected = input.deserialize[(Int, Int)]
+ .where(callFunction(f1, BooleanType, 'obj))
+ .select('obj.as("obj"))
+ .where(callFunction(f2, BooleanType, 'obj))
+ .serialize[(Int, Int)].analyze
+
+ comparePlans(optimized, expected)
+ }
+
+ test("embed deserializer in filter condition if there is only one filter") {
+ val input = LocalRelation('_1.int, '_2.int)
+ val f = (i: (Int, Int)) => i._1 > 0
+
+ val query = input.filter(f).analyze
+
+ val optimized = Optimize.execute(query)
+
+ val deserializer = UnresolvedDeserializer(encoderFor[(Int, Int)].deserializer)
+ val condition = callFunction(f, BooleanType, deserializer)
+ val expected = input.where(condition).analyze
+
+ comparePlans(optimized, expected)
+ }
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/CatalystQlSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/CatalystQlSuite.scala
deleted file mode 100644
index c068e895b6..0000000000
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/CatalystQlSuite.scala
+++ /dev/null
@@ -1,243 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.catalyst.parser
-
-import org.apache.spark.sql.AnalysisException
-import org.apache.spark.sql.catalyst.TableIdentifier
-import org.apache.spark.sql.catalyst.analysis._
-import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.PlanTest
-import org.apache.spark.sql.catalyst.plans.logical._
-import org.apache.spark.unsafe.types.CalendarInterval
-
-class CatalystQlSuite extends PlanTest {
- val parser = new CatalystQl()
-
- test("test case insensitive") {
- val result = Project(UnresolvedAlias(Literal(1)):: Nil, OneRowRelation)
- assert(result === parser.parsePlan("seLect 1"))
- assert(result === parser.parsePlan("select 1"))
- assert(result === parser.parsePlan("SELECT 1"))
- }
-
- test("test NOT operator with comparison operations") {
- val parsed = parser.parsePlan("SELECT NOT TRUE > TRUE")
- val expected = Project(
- UnresolvedAlias(
- Not(
- GreaterThan(Literal(true), Literal(true)))
- ) :: Nil,
- OneRowRelation)
- comparePlans(parsed, expected)
- }
-
- test("test Union Distinct operator") {
- val parsed1 = parser.parsePlan("SELECT * FROM t0 UNION SELECT * FROM t1")
- val parsed2 = parser.parsePlan("SELECT * FROM t0 UNION DISTINCT SELECT * FROM t1")
- val expected =
- Project(UnresolvedAlias(UnresolvedStar(None)) :: Nil,
- SubqueryAlias("u_1",
- Distinct(
- Union(
- Project(UnresolvedAlias(UnresolvedStar(None)) :: Nil,
- UnresolvedRelation(TableIdentifier("t0"), None)),
- Project(UnresolvedAlias(UnresolvedStar(None)) :: Nil,
- UnresolvedRelation(TableIdentifier("t1"), None))))))
- comparePlans(parsed1, expected)
- comparePlans(parsed2, expected)
- }
-
- test("test Union All operator") {
- val parsed = parser.parsePlan("SELECT * FROM t0 UNION ALL SELECT * FROM t1")
- val expected =
- Project(UnresolvedAlias(UnresolvedStar(None)) :: Nil,
- SubqueryAlias("u_1",
- Union(
- Project(UnresolvedAlias(UnresolvedStar(None)) :: Nil,
- UnresolvedRelation(TableIdentifier("t0"), None)),
- Project(UnresolvedAlias(UnresolvedStar(None)) :: Nil,
- UnresolvedRelation(TableIdentifier("t1"), None)))))
- comparePlans(parsed, expected)
- }
-
- test("support hive interval literal") {
- def checkInterval(sql: String, result: CalendarInterval): Unit = {
- val parsed = parser.parsePlan(sql)
- val expected = Project(
- UnresolvedAlias(
- Literal(result)
- ) :: Nil,
- OneRowRelation)
- comparePlans(parsed, expected)
- }
-
- def checkYearMonth(lit: String): Unit = {
- checkInterval(
- s"SELECT INTERVAL '$lit' YEAR TO MONTH",
- CalendarInterval.fromYearMonthString(lit))
- }
-
- def checkDayTime(lit: String): Unit = {
- checkInterval(
- s"SELECT INTERVAL '$lit' DAY TO SECOND",
- CalendarInterval.fromDayTimeString(lit))
- }
-
- def checkSingleUnit(lit: String, unit: String): Unit = {
- checkInterval(
- s"SELECT INTERVAL '$lit' $unit",
- CalendarInterval.fromSingleUnitString(unit, lit))
- }
-
- checkYearMonth("123-10")
- checkYearMonth("496-0")
- checkYearMonth("-2-3")
- checkYearMonth("-123-0")
-
- checkDayTime("99 11:22:33.123456789")
- checkDayTime("-99 11:22:33.123456789")
- checkDayTime("10 9:8:7.123456789")
- checkDayTime("1 0:0:0")
- checkDayTime("-1 0:0:0")
- checkDayTime("1 0:0:1")
-
- for (unit <- Seq("year", "month", "day", "hour", "minute", "second")) {
- checkSingleUnit("7", unit)
- checkSingleUnit("-7", unit)
- checkSingleUnit("0", unit)
- }
-
- checkSingleUnit("13.123456789", "second")
- checkSingleUnit("-13.123456789", "second")
- }
-
- test("support scientific notation") {
- def assertRight(input: String, output: Double): Unit = {
- val parsed = parser.parsePlan("SELECT " + input)
- val expected = Project(
- UnresolvedAlias(
- Literal(output)
- ) :: Nil,
- OneRowRelation)
- comparePlans(parsed, expected)
- }
-
- assertRight("9.0e1", 90)
- assertRight(".9e+2", 90)
- assertRight("0.9e+2", 90)
- assertRight("900e-1", 90)
- assertRight("900.0E-1", 90)
- assertRight("9.e+1", 90)
-
- intercept[AnalysisException](parser.parsePlan("SELECT .e3"))
- }
-
- test("parse expressions") {
- compareExpressions(
- parser.parseExpression("prinln('hello', 'world')"),
- UnresolvedFunction(
- "prinln", Literal("hello") :: Literal("world") :: Nil, false))
-
- compareExpressions(
- parser.parseExpression("1 + r.r As q"),
- Alias(Add(Literal(1), UnresolvedAttribute("r.r")), "q")())
-
- compareExpressions(
- parser.parseExpression("1 - f('o', o(bar))"),
- Subtract(Literal(1),
- UnresolvedFunction("f",
- Literal("o") ::
- UnresolvedFunction("o", UnresolvedAttribute("bar") :: Nil, false) ::
- Nil, false)))
-
- intercept[AnalysisException](parser.parseExpression("1 - f('o', o(bar)) hello * world"))
- }
-
- test("table identifier") {
- assert(TableIdentifier("q") === parser.parseTableIdentifier("q"))
- assert(TableIdentifier("q", Some("d")) === parser.parseTableIdentifier("d.q"))
- intercept[AnalysisException](parser.parseTableIdentifier(""))
- intercept[AnalysisException](parser.parseTableIdentifier("d.q.g"))
- }
-
- test("parse union/except/intersect") {
- parser.parsePlan("select * from t1 union all select * from t2")
- parser.parsePlan("select * from t1 union distinct select * from t2")
- parser.parsePlan("select * from t1 union select * from t2")
- parser.parsePlan("select * from t1 except select * from t2")
- parser.parsePlan("select * from t1 intersect select * from t2")
- parser.parsePlan("(select * from t1) union all (select * from t2)")
- parser.parsePlan("(select * from t1) union distinct (select * from t2)")
- parser.parsePlan("(select * from t1) union (select * from t2)")
- parser.parsePlan("select * from ((select * from t1) union (select * from t2)) t")
- }
-
- test("window function: better support of parentheses") {
- parser.parsePlan("select sum(product + 1) over (partition by ((1) + (product / 2)) " +
- "order by 2) from windowData")
- parser.parsePlan("select sum(product + 1) over (partition by (1 + (product / 2)) " +
- "order by 2) from windowData")
- parser.parsePlan("select sum(product + 1) over (partition by ((product / 2) + 1) " +
- "order by 2) from windowData")
-
- parser.parsePlan("select sum(product + 1) over (partition by ((product) + (1)) order by 2) " +
- "from windowData")
- parser.parsePlan("select sum(product + 1) over (partition by ((product) + 1) order by 2) " +
- "from windowData")
- parser.parsePlan("select sum(product + 1) over (partition by (product + (1)) order by 2) " +
- "from windowData")
- }
-
- test("very long AND/OR expression") {
- val equals = (1 to 1000).map(x => s"$x == $x")
- val expr = parser.parseExpression(equals.mkString(" AND "))
- assert(expr.isInstanceOf[And])
- assert(expr.collect( { case EqualTo(_, _) => true } ).size == 1000)
-
- val expr2 = parser.parseExpression(equals.mkString(" OR "))
- assert(expr2.isInstanceOf[Or])
- assert(expr2.collect( { case EqualTo(_, _) => true } ).size == 1000)
- }
-
- test("subquery") {
- parser.parsePlan("select (select max(b) from s) ss from t")
- parser.parsePlan("select * from t where a = (select b from s)")
- parser.parsePlan("select * from t group by g having a > (select b from s)")
- }
-
- test("using clause in JOIN") {
- // Tests parsing of using clause for different join types.
- parser.parsePlan("select * from t1 join t2 using (c1)")
- parser.parsePlan("select * from t1 join t2 using (c1, c2)")
- parser.parsePlan("select * from t1 left join t2 using (c1, c2)")
- parser.parsePlan("select * from t1 right join t2 using (c1, c2)")
- parser.parsePlan("select * from t1 full outer join t2 using (c1, c2)")
- parser.parsePlan("select * from t1 join t2 using (c1) join t3 using (c2)")
- // Tests errors
- // (1) Empty using clause
- // (2) Qualified columns in using
- // (3) Both on and using clause
- var error = intercept[AnalysisException](parser.parsePlan("select * from t1 join t2 using ()"))
- assert(error.message.contains("cannot recognize input near ')'"))
- error = intercept[AnalysisException](parser.parsePlan("select * from t1 join t2 using (t1.c1)"))
- assert(error.message.contains("mismatched input '.'"))
- error = intercept[AnalysisException](parser.parsePlan("select * from t1" +
- " join t2 using (c1) on t1.c1 = t2.c1"))
- assert(error.message.contains("missing EOF at 'on' near ')'"))
- }
-}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DataTypeParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DataTypeParserSuite.scala
index 7d3608033b..07b89cb61f 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DataTypeParserSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DataTypeParserSuite.scala
@@ -20,17 +20,21 @@ package org.apache.spark.sql.catalyst.parser
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.types._
-class DataTypeParserSuite extends SparkFunSuite {
+abstract class AbstractDataTypeParserSuite extends SparkFunSuite {
+
+ def parse(sql: String): DataType
def checkDataType(dataTypeString: String, expectedDataType: DataType): Unit = {
test(s"parse ${dataTypeString.replace("\n", "")}") {
- assert(DataTypeParser.parse(dataTypeString) === expectedDataType)
+ assert(parse(dataTypeString) === expectedDataType)
}
}
+ def intercept(sql: String)
+
def unsupported(dataTypeString: String): Unit = {
test(s"$dataTypeString is not supported") {
- intercept[DataTypeException](DataTypeParser.parse(dataTypeString))
+ intercept(dataTypeString)
}
}
@@ -97,13 +101,6 @@ class DataTypeParserSuite extends SparkFunSuite {
StructField("arrAy", ArrayType(DoubleType, true), true) ::
StructField("anotherArray", ArrayType(StringType, true), true) :: Nil)
)
- // A column name can be a reserved word in our DDL parser and SqlParser.
- checkDataType(
- "Struct<TABLE: string, CASE:boolean>",
- StructType(
- StructField("TABLE", StringType, true) ::
- StructField("CASE", BooleanType, true) :: Nil)
- )
// Use backticks to quote column names having special characters.
checkDataType(
"struct<`x+y`:int, `!@#$%^&*()`:string, `1_2.345<>:\"`:varchar(20)>",
@@ -118,6 +115,43 @@ class DataTypeParserSuite extends SparkFunSuite {
unsupported("it is not a data type")
unsupported("struct<x+y: int, 1.1:timestamp>")
unsupported("struct<x: int")
+}
+
+class DataTypeParserSuite extends AbstractDataTypeParserSuite {
+ override def intercept(sql: String): Unit =
+ intercept[DataTypeException](DataTypeParser.parse(sql))
+
+ override def parse(sql: String): DataType =
+ DataTypeParser.parse(sql)
+
+ // A column name can be a reserved word in our DDL parser and SqlParser.
+ checkDataType(
+ "Struct<TABLE: string, CASE:boolean>",
+ StructType(
+ StructField("TABLE", StringType, true) ::
+ StructField("CASE", BooleanType, true) :: Nil)
+ )
+
unsupported("struct<x int, y string>")
+
unsupported("struct<`x``y` int>")
}
+
+class CatalystQlDataTypeParserSuite extends AbstractDataTypeParserSuite {
+ override def intercept(sql: String): Unit =
+ intercept[ParseException](CatalystSqlParser.parseDataType(sql))
+
+ override def parse(sql: String): DataType =
+ CatalystSqlParser.parseDataType(sql)
+
+ // A column name can be a reserved word in our DDL parser and SqlParser.
+ unsupported("Struct<TABLE: string, CASE:boolean>")
+
+ checkDataType(
+ "struct<x int, y string>",
+ (new StructType).add("x", IntegerType).add("y", StringType))
+
+ checkDataType(
+ "struct<`x``y` int>",
+ (new StructType).add("x`y", IntegerType))
+}
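
The split into AbstractDataTypeParserSuite makes the behavioural difference between the two parsers explicit: the legacy DataTypeParser still accepts reserved words as bare field names, while the new CatalystSqlParser rejects them but accepts Hive-style space-separated fields and escaped backticks in quoted names. A minimal usage sketch of the new entry point (a hedged illustration; the value name is not part of the patch):

    import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
    import org.apache.spark.sql.types._

    // Space-separated field syntax is accepted by the new parser...
    val parsed = CatalystSqlParser.parseDataType("struct<x int, y string>")
    assert(parsed == (new StructType).add("x", IntegerType).add("y", StringType))
    // ...while an unquoted reserved word such as TABLE now raises a ParseException.
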
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ErrorParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ErrorParserSuite.scala
new file mode 100644
index 0000000000..db96bfb652
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ErrorParserSuite.scala
@@ -0,0 +1,67 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.catalyst.parser
+
+import org.apache.spark.SparkFunSuite
+
+/**
+ * Test various parser errors.
+ */
+class ErrorParserSuite extends SparkFunSuite {
+ def intercept(sql: String, line: Int, startPosition: Int, messages: String*): Unit = {
+ val e = intercept[ParseException](CatalystSqlParser.parsePlan(sql))
+
+ // Check position.
+ assert(e.line.isDefined)
+ assert(e.line.get === line)
+ assert(e.startPosition.isDefined)
+ assert(e.startPosition.get === startPosition)
+
+ // Check messages.
+ val error = e.getMessage
+ messages.foreach { message =>
+ assert(error.contains(message))
+ }
+ }
+
+ test("no viable input") {
+ intercept("select from tbl", 1, 7, "no viable alternative at input", "-------^^^")
+ intercept("select\nfrom tbl", 2, 0, "no viable alternative at input", "^^^")
+ intercept("select ((r + 1) ", 1, 16, "no viable alternative at input", "----------------^^^")
+ }
+
+ test("extraneous input") {
+ intercept("select 1 1", 1, 9, "extraneous input '1' expecting", "---------^^^")
+ intercept("select *\nfrom r as q t", 2, 12, "extraneous input", "------------^^^")
+ }
+
+ test("mismatched input") {
+ intercept("select * from r order by q from t", 1, 27,
+ "mismatched input",
+ "---------------------------^^^")
+ intercept("select *\nfrom r\norder by q\nfrom t", 4, 0, "mismatched input", "^^^")
+ }
+
+ test("semantic errors") {
+ intercept("select *\nfrom r\norder by q\ncluster by q", 3, 0,
+ "Combination of ORDER BY/SORT BY/DISTRIBUTE BY/CLUSTER BY is not supported",
+ "^^^")
+ intercept("select * from r where a in (select * from t)", 1, 24,
+ "IN with a Sub-query is currently not supported",
+ "------------------------^^^")
+ }
+}
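
The intercept helper above doubles as documentation for what a ParseException carries: an optional line, an optional start position, and a message ending in a caret marker that points at the offending token. A hedged sketch of inspecting one outside a test, using only the API the assertions above rely on:

    import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParseException}

    try {
      CatalystSqlParser.parsePlan("select from tbl")
    } catch {
      case e: ParseException =>
        // Both position fields are Options, exactly as the suite asserts.
        println(s"line=${e.line.getOrElse(-1)} pos=${e.startPosition.getOrElse(-1)}")
        println(e.getMessage) // includes the "no viable alternative ..." text and the ^^^ marker
    }
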
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala
new file mode 100644
index 0000000000..6f40ec67ec
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala
@@ -0,0 +1,497 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.catalyst.parser
+
+import java.sql.{Date, Timestamp}
+
+import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, _}
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.PlanTest
+import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.CalendarInterval
+
+/**
+ * Test basic expression parsing. If a type of expression is supported it should be tested here.
+ *
+ * Please note that some of the expressions tested don't have to be sound expressions; only
+ * their structure needs to be valid. Unsound expressions should be caught by the Analyzer or
+ * CheckAnalysis classes.
+ */
+class ExpressionParserSuite extends PlanTest {
+ import CatalystSqlParser._
+ import org.apache.spark.sql.catalyst.dsl.expressions._
+ import org.apache.spark.sql.catalyst.dsl.plans._
+
+ def assertEqual(sqlCommand: String, e: Expression): Unit = {
+ compareExpressions(parseExpression(sqlCommand), e)
+ }
+
+ def intercept(sqlCommand: String, messages: String*): Unit = {
+ val e = intercept[ParseException](parseExpression(sqlCommand))
+ messages.foreach { message =>
+ assert(e.message.contains(message))
+ }
+ }
+
+ test("star expressions") {
+ // Global Star
+ assertEqual("*", UnresolvedStar(None))
+
+ // Targeted Star
+ assertEqual("a.b.*", UnresolvedStar(Option(Seq("a", "b"))))
+ }
+
+ // NamedExpression (Alias/Multialias)
+ test("named expressions") {
+ // No Alias
+ val r0 = 'a
+ assertEqual("a", r0)
+
+ // Single Alias.
+ val r1 = 'a as "b"
+ assertEqual("a as b", r1)
+ assertEqual("a b", r1)
+
+ // Multi-Alias
+ assertEqual("a as (b, c)", MultiAlias('a, Seq("b", "c")))
+ assertEqual("a() (b, c)", MultiAlias('a.function(), Seq("b", "c")))
+
+ // Numeric literals without a space between the literal qualifier and the alias should not be
+ // interpreted as such; an unresolved reference should be returned instead.
+ // TODO add the JIRA-ticket number.
+ assertEqual("1SL", Symbol("1SL"))
+
+ // Aliased star is allowed.
+ assertEqual("a.* b", UnresolvedStar(Option(Seq("a"))) as 'b)
+ }
+
+ test("binary logical expressions") {
+ // And
+ assertEqual("a and b", 'a && 'b)
+
+ // Or
+ assertEqual("a or b", 'a || 'b)
+
+ // Combination And/Or check precedence
+ assertEqual("a and b or c and d", ('a && 'b) || ('c && 'd))
+ assertEqual("a or b or c and d", 'a || 'b || ('c && 'd))
+
+ // Multiple AND/OR get converted into a balanced tree
+ assertEqual("a or b or c or d or e or f", (('a || 'b) || 'c) || (('d || 'e) || 'f))
+ assertEqual("a and b and c and d and e and f", (('a && 'b) && 'c) && (('d && 'e) && 'f))
+ }
+
+ test("long binary logical expressions") {
+ def testVeryBinaryExpression(op: String, clazz: Class[_]): Unit = {
+ val sql = (1 to 1000).map(x => s"$x == $x").mkString(op)
+ val e = parseExpression(sql)
+ assert(e.collect { case _: EqualTo => true }.size === 1000)
+ assert(e.collect { case x if clazz.isInstance(x) => true }.size === 999)
+ }
+ testVeryBinaryExpression(" AND ", classOf[And])
+ testVeryBinaryExpression(" OR ", classOf[Or])
+ }
+
+ test("not expressions") {
+ assertEqual("not a", !'a)
+ assertEqual("!a", !'a)
+ assertEqual("not true > true", Not(GreaterThan(true, true)))
+ }
+
+ test("exists expression") {
+ intercept("exists (select 1 from b where b.x = a.x)", "EXISTS clauses are not supported")
+ }
+
+ test("comparison expressions") {
+ assertEqual("a = b", 'a === 'b)
+ assertEqual("a == b", 'a === 'b)
+ assertEqual("a <=> b", 'a <=> 'b)
+ assertEqual("a <> b", 'a =!= 'b)
+ assertEqual("a != b", 'a =!= 'b)
+ assertEqual("a < b", 'a < 'b)
+ assertEqual("a <= b", 'a <= 'b)
+ assertEqual("a > b", 'a > 'b)
+ assertEqual("a >= b", 'a >= 'b)
+ }
+
+ test("between expressions") {
+ assertEqual("a between b and c", 'a >= 'b && 'a <= 'c)
+ assertEqual("a not between b and c", !('a >= 'b && 'a <= 'c))
+ }
+
+ test("in expressions") {
+ assertEqual("a in (b, c, d)", 'a in ('b, 'c, 'd))
+ assertEqual("a not in (b, c, d)", !('a in ('b, 'c, 'd)))
+ }
+
+ test("in sub-query") {
+ intercept("a in (select b from c)", "IN with a Sub-query is currently not supported")
+ }
+
+ test("like expressions") {
+ assertEqual("a like 'pattern%'", 'a like "pattern%")
+ assertEqual("a not like 'pattern%'", !('a like "pattern%"))
+ assertEqual("a rlike 'pattern%'", 'a rlike "pattern%")
+ assertEqual("a not rlike 'pattern%'", !('a rlike "pattern%"))
+ assertEqual("a regexp 'pattern%'", 'a rlike "pattern%")
+ assertEqual("a not regexp 'pattern%'", !('a rlike "pattern%"))
+ }
+
+ test("is null expressions") {
+ assertEqual("a is null", 'a.isNull)
+ assertEqual("a is not null", 'a.isNotNull)
+ assertEqual("a = b is null", ('a === 'b).isNull)
+ assertEqual("a = b is not null", ('a === 'b).isNotNull)
+ }
+
+ test("binary arithmetic expressions") {
+ // Simple operations
+ assertEqual("a * b", 'a * 'b)
+ assertEqual("a / b", 'a / 'b)
+ assertEqual("a DIV b", ('a / 'b).cast(LongType))
+ assertEqual("a % b", 'a % 'b)
+ assertEqual("a + b", 'a + 'b)
+ assertEqual("a - b", 'a - 'b)
+ assertEqual("a & b", 'a & 'b)
+ assertEqual("a ^ b", 'a ^ 'b)
+ assertEqual("a | b", 'a | 'b)
+
+ // Check precedences
+ assertEqual(
+ "a * t | b ^ c & d - e + f % g DIV h / i * k",
+ 'a * 't | ('b ^ ('c & ('d - 'e + (('f % 'g / 'h).cast(LongType) / 'i * 'k)))))
+ }
+
+ test("unary arithmetic expressions") {
+ assertEqual("+a", 'a)
+ assertEqual("-a", -'a)
+ assertEqual("~a", ~'a)
+ assertEqual("-+~~a", -(~(~'a)))
+ }
+
+ test("cast expressions") {
+ // Note that DataType parsing is tested elsewhere.
+ assertEqual("cast(a as int)", 'a.cast(IntegerType))
+ assertEqual("cast(a as timestamp)", 'a.cast(TimestampType))
+ assertEqual("cast(a as array<int>)", 'a.cast(ArrayType(IntegerType)))
+ assertEqual("cast(cast(a as int) as long)", 'a.cast(IntegerType).cast(LongType))
+ }
+
+ test("function expressions") {
+ assertEqual("foo()", 'foo.function())
+ assertEqual("foo.bar()", Symbol("foo.bar").function())
+ assertEqual("foo(*)", 'foo.function(star()))
+ assertEqual("count(*)", 'count.function(1))
+ assertEqual("foo(a, b)", 'foo.function('a, 'b))
+ assertEqual("foo(all a, b)", 'foo.function('a, 'b))
+ assertEqual("foo(distinct a, b)", 'foo.distinctFunction('a, 'b))
+ assertEqual("grouping(distinct a, b)", 'grouping.distinctFunction('a, 'b))
+ assertEqual("`select`(all a, b)", 'select.function('a, 'b))
+ }
+
+ test("window function expressions") {
+ val func = 'foo.function(star())
+ def windowed(
+ partitioning: Seq[Expression] = Seq.empty,
+ ordering: Seq[SortOrder] = Seq.empty,
+ frame: WindowFrame = UnspecifiedFrame): Expression = {
+ WindowExpression(func, WindowSpecDefinition(partitioning, ordering, frame))
+ }
+
+ // Basic window testing.
+ assertEqual("foo(*) over w1", UnresolvedWindowExpression(func, WindowSpecReference("w1")))
+ assertEqual("foo(*) over ()", windowed())
+ assertEqual("foo(*) over (partition by a, b)", windowed(Seq('a, 'b)))
+ assertEqual("foo(*) over (distribute by a, b)", windowed(Seq('a, 'b)))
+ assertEqual("foo(*) over (cluster by a, b)", windowed(Seq('a, 'b)))
+ assertEqual("foo(*) over (order by a desc, b asc)", windowed(Seq.empty, Seq('a.desc, 'b.asc )))
+ assertEqual("foo(*) over (sort by a desc, b asc)", windowed(Seq.empty, Seq('a.desc, 'b.asc )))
+ assertEqual("foo(*) over (partition by a, b order by c)", windowed(Seq('a, 'b), Seq('c.asc)))
+ assertEqual("foo(*) over (distribute by a, b sort by c)", windowed(Seq('a, 'b), Seq('c.asc)))
+
+ // Test use of expressions in window functions.
+ assertEqual(
+ "sum(product + 1) over (partition by ((product) + (1)) order by 2)",
+ WindowExpression('sum.function('product + 1),
+ WindowSpecDefinition(Seq('product + 1), Seq(Literal(2).asc), UnspecifiedFrame)))
+ assertEqual(
+ "sum(product + 1) over (partition by ((product / 2) + 1) order by 2)",
+ WindowExpression('sum.function('product + 1),
+ WindowSpecDefinition(Seq('product / 2 + 1), Seq(Literal(2).asc), UnspecifiedFrame)))
+
+ // Range/Row
+ val frameTypes = Seq(("rows", RowFrame), ("range", RangeFrame))
+ val boundaries = Seq(
+ ("10 preceding", ValuePreceding(10), CurrentRow),
+ ("3 + 1 following", ValueFollowing(4), CurrentRow), // Will fail during analysis
+ ("unbounded preceding", UnboundedPreceding, CurrentRow),
+ ("unbounded following", UnboundedFollowing, CurrentRow), // Will fail during analysis
+ ("between unbounded preceding and current row", UnboundedPreceding, CurrentRow),
+ ("between unbounded preceding and unbounded following",
+ UnboundedPreceding, UnboundedFollowing),
+ ("between 10 preceding and current row", ValuePreceding(10), CurrentRow),
+ ("between current row and 5 following", CurrentRow, ValueFollowing(5)),
+ ("between 10 preceding and 5 following", ValuePreceding(10), ValueFollowing(5))
+ )
+ frameTypes.foreach {
+ case (frameTypeSql, frameType) =>
+ boundaries.foreach {
+ case (boundarySql, begin, end) =>
+ val query = s"foo(*) over (partition by a order by b $frameTypeSql $boundarySql)"
+ val expr = windowed(Seq('a), Seq('b.asc), SpecifiedWindowFrame(frameType, begin, end))
+ assertEqual(query, expr)
+ }
+ }
+
+ // We cannot use non-integer constants.
+ intercept("foo(*) over (partition by a order by b rows 10.0 preceding)",
+ "Frame bound value must be a constant integer.")
+
+ // We cannot use an arbitrary expression.
+ intercept("foo(*) over (partition by a order by b rows exp(b) preceding)",
+ "Frame bound value must be a constant integer.")
+ }
+
+ test("row constructor") {
+ // Note that '(a)' will be interpreted as a nested expression.
+ assertEqual("(a, b)", CreateStruct(Seq('a, 'b)))
+ assertEqual("(a, b, c)", CreateStruct(Seq('a, 'b, 'c)))
+ }
+
+ test("scalar sub-query") {
+ assertEqual(
+ "(select max(val) from tbl) > current",
+ ScalarSubquery(table("tbl").select('max.function('val))) > 'current)
+ assertEqual(
+ "a = (select b from s)",
+ 'a === ScalarSubquery(table("s").select('b)))
+ }
+
+ test("case when") {
+ assertEqual("case a when 1 then b when 2 then c else d end",
+ CaseKeyWhen('a, Seq(1, 'b, 2, 'c, 'd)))
+ assertEqual("case when a = 1 then b when a = 2 then c else d end",
+ CaseWhen(Seq(('a === 1, 'b.expr), ('a === 2, 'c.expr)), 'd))
+ }
+
+ test("dereference") {
+ assertEqual("a.b", UnresolvedAttribute("a.b"))
+ assertEqual("`select`.b", UnresolvedAttribute("select.b"))
+ assertEqual("(a + b).b", ('a + 'b).getField("b")) // This will fail analysis.
+ assertEqual("struct(a, b).b", 'struct.function('a, 'b).getField("b"))
+ }
+
+ test("reference") {
+ // Regular
+ assertEqual("a", 'a)
+
+ // Starting with a digit.
+ assertEqual("1a", Symbol("1a"))
+
+ // Quoted using a keyword.
+ assertEqual("`select`", 'select)
+
+ // Unquoted using an unreserved keyword.
+ assertEqual("columns", 'columns)
+ }
+
+ test("subscript") {
+ assertEqual("a[b]", 'a.getItem('b))
+ assertEqual("a[1 + 1]", 'a.getItem(Literal(1) + 1))
+ assertEqual("`c`.a[b]", UnresolvedAttribute("c.a").getItem('b))
+ }
+
+ test("parenthesis") {
+ assertEqual("(a)", 'a)
+ assertEqual("r * (a + b)", 'r * ('a + 'b))
+ }
+
+ test("type constructors") {
+ // Dates.
+ assertEqual("dAte '2016-03-11'", Literal(Date.valueOf("2016-03-11")))
+ intercept[IllegalArgumentException] {
+ parseExpression("DAtE 'mar 11 2016'")
+ }
+
+ // Timestamps.
+ assertEqual("tImEstAmp '2016-03-11 20:54:00.000'",
+ Literal(Timestamp.valueOf("2016-03-11 20:54:00.000")))
+ intercept[IllegalArgumentException] {
+ parseExpression("timestamP '2016-33-11 20:54:00.000'")
+ }
+
+ // Unsupported datatype.
+ intercept("GEO '(10,-6)'", "Literals of type 'GEO' are currently not supported.")
+ }
+
+ test("literals") {
+ // NULL
+ assertEqual("null", Literal(null))
+
+ // Boolean
+ assertEqual("trUe", Literal(true))
+ assertEqual("False", Literal(false))
+
+ // Integral should have the narrowest possible type
+ assertEqual("787324", Literal(787324))
+ assertEqual("7873247234798249234", Literal(7873247234798249234L))
+ assertEqual("78732472347982492793712334",
+ Literal(BigDecimal("78732472347982492793712334").underlying()))
+
+ // Decimal
+ assertEqual("7873247234798249279371.2334",
+ Literal(BigDecimal("7873247234798249279371.2334").underlying()))
+
+ // Scientific Decimal
+ assertEqual("9.0e1", 90d)
+ assertEqual(".9e+2", 90d)
+ assertEqual("0.9e+2", 90d)
+ assertEqual("900e-1", 90d)
+ assertEqual("900.0E-1", 90d)
+ assertEqual("9.e+1", 90d)
+ intercept(".e3")
+
+ // Tiny Int Literal
+ assertEqual("10Y", Literal(10.toByte))
+ intercept("-1000Y")
+
+ // Small Int Literal
+ assertEqual("10S", Literal(10.toShort))
+ intercept("40000S")
+
+ // Long Int Literal
+ assertEqual("10L", Literal(10L))
+ intercept("78732472347982492793712334L")
+
+ // Double Literal
+ assertEqual("10.0D", Literal(10.0D))
+ // TODO we need to figure out if we should throw an exception here!
+ assertEqual("1E309", Literal(Double.PositiveInfinity))
+ }
+
+ test("strings") {
+ // Single Strings.
+ assertEqual("\"hello\"", "hello")
+ assertEqual("'hello'", "hello")
+
+ // Multi-Strings.
+ assertEqual("\"hello\" 'world'", "helloworld")
+ assertEqual("'hello' \" \" 'world'", "hello world")
+
+ // 'LIKE' string literals. Notice that an escaped '%' is the same as an escaped '\' and a
+ // regular '%'; to get the correct result you need to add another escaped '\'.
+ // TODO figure out if we shouldn't change the ParseUtils.unescapeSQLString method?
+ assertEqual("'pattern%'", "pattern%")
+ assertEqual("'no-pattern\\%'", "no-pattern\\%")
+ assertEqual("'pattern\\\\%'", "pattern\\%")
+ assertEqual("'pattern\\\\\\%'", "pattern\\\\%")
+
+ // Escaped characters.
+ // See: http://dev.mysql.com/doc/refman/5.7/en/string-literals.html
+ assertEqual("'\\0'", "\u0000") // ASCII NUL (X'00')
+ assertEqual("'\\''", "\'") // Single quote
+ assertEqual("'\\\"'", "\"") // Double quote
+ assertEqual("'\\b'", "\b") // Backspace
+ assertEqual("'\\n'", "\n") // Newline
+ assertEqual("'\\r'", "\r") // Carriage return
+ assertEqual("'\\t'", "\t") // Tab character
+ assertEqual("'\\Z'", "\u001A") // ASCII 26 - CTRL + Z (EOF on windows)
+
+ // Octals
+ assertEqual("'\\110\\145\\154\\154\\157\\041'", "Hello!")
+
+ // Unicode
+ assertEqual("'\\u0057\\u006F\\u0072\\u006C\\u0064\\u0020\\u003A\\u0029'", "World :)")
+ }
+
+ test("intervals") {
+ def intervalLiteral(u: String, s: String): Literal = {
+ Literal(CalendarInterval.fromSingleUnitString(u, s))
+ }
+
+ // Empty interval statement
+ intercept("interval", "at least one time unit should be given for interval literal")
+
+ // Single Intervals.
+ val units = Seq(
+ "year",
+ "month",
+ "week",
+ "day",
+ "hour",
+ "minute",
+ "second",
+ "millisecond",
+ "microsecond")
+ val forms = Seq("", "s")
+ val values = Seq("0", "10", "-7", "21")
+ units.foreach { unit =>
+ forms.foreach { form =>
+ values.foreach { value =>
+ val expected = intervalLiteral(unit, value)
+ assertEqual(s"interval $value $unit$form", expected)
+ assertEqual(s"interval '$value' $unit$form", expected)
+ }
+ }
+ }
+
+ // Hive nanosecond notation.
+ assertEqual("interval 13.123456789 seconds", intervalLiteral("second", "13.123456789"))
+ assertEqual("interval -13.123456789 second", intervalLiteral("second", "-13.123456789"))
+
+ // Non Existing unit
+ intercept("interval 10 nanoseconds", "No interval can be constructed")
+
+ // Year-Month intervals.
+ val yearMonthValues = Seq("123-10", "496-0", "-2-3", "-123-0")
+ yearMonthValues.foreach { value =>
+ val result = Literal(CalendarInterval.fromYearMonthString(value))
+ assertEqual(s"interval '$value' year to month", result)
+ }
+
+ // Day-Time intervals.
+ val dayTimeValues = Seq(
+ "99 11:22:33.123456789",
+ "-99 11:22:33.123456789",
+ "10 9:8:7.123456789",
+ "1 0:0:0",
+ "-1 0:0:0",
+ "1 0:0:1")
+ dayTimeValues.foreach { value =>
+ val result = Literal(CalendarInterval.fromDayTimeString(value))
+ assertEqual(s"interval '$value' day to second", result)
+ }
+
+ // Unknown FROM TO intervals
+ intercept("interval 10 month to second", "Intervals FROM month TO second are not supported.")
+
+ // Composed intervals.
+ assertEqual(
+ "interval 3 months 22 seconds 1 millisecond",
+ Literal(new CalendarInterval(3, 22001000L)))
+ assertEqual(
+ "interval 3 years '-1-10' year to month 3 weeks '1 0:0:2' day to second",
+ Literal(new CalendarInterval(14,
+ 22 * CalendarInterval.MICROS_PER_DAY + 2 * CalendarInterval.MICROS_PER_SECOND)))
+ }
+
+ test("composed expressions") {
+ assertEqual("1 + r.r As q", (Literal(1) + UnresolvedAttribute("r.r")).as("q"))
+ assertEqual("1 - f('o', o(bar))", Literal(1) - 'f.function("o", 'o.function('bar)))
+ intercept("1 - f('o', o(bar)) hello * world", "mismatched input '*'")
+ }
+}
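
Because the suite imports CatalystSqlParser._, parseExpression and parsePlan are called unqualified throughout; outside a test the same entry points live on the object itself. A small usage sketch (results come back unresolved and are only checked later by the Analyzer, as the suite's scaladoc notes; the value names are illustrative):

    import org.apache.spark.sql.catalyst.parser.CatalystSqlParser

    // An unresolved expression tree: roughly Alias(Add(UnresolvedAttribute("a"), Literal(1)), "b").
    val expr = CatalystSqlParser.parseExpression("a + 1 as b")
    // And a full unresolved logical plan for a query.
    val plan = CatalystSqlParser.parsePlan("select a, b from db.c where x < 1")
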
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ParserUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ParserUtilsSuite.scala
new file mode 100644
index 0000000000..d090daf7b4
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ParserUtilsSuite.scala
@@ -0,0 +1,65 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.catalyst.parser
+
+import org.apache.spark.SparkFunSuite
+
+class ParserUtilsSuite extends SparkFunSuite {
+
+ import ParserUtils._
+
+ test("unescapeSQLString") {
+ // scalastyle:off nonascii
+
+ // String not including escaped characters and enclosed by double quotes.
+ assert(unescapeSQLString(""""abcdefg"""") == "abcdefg")
+
+ // String enclosed by single quotes.
+ assert(unescapeSQLString("""'C0FFEE'""") == "C0FFEE")
+
+ // Strings including single escaped characters.
+ assert(unescapeSQLString("""'\0'""") == "\u0000")
+ assert(unescapeSQLString(""""\'"""") == "\'")
+ assert(unescapeSQLString("""'\"'""") == "\"")
+ assert(unescapeSQLString(""""\b"""") == "\b")
+ assert(unescapeSQLString("""'\n'""") == "\n")
+ assert(unescapeSQLString(""""\r"""") == "\r")
+ assert(unescapeSQLString("""'\t'""") == "\t")
+ assert(unescapeSQLString(""""\Z"""") == "\u001A")
+ assert(unescapeSQLString("""'\\'""") == "\\")
+ assert(unescapeSQLString(""""\%"""") == "\\%")
+ assert(unescapeSQLString("""'\_'""") == "\\_")
+
+ // String including '\000' style literal characters.
+ assert(unescapeSQLString("""'3 + 5 = \070'""") == "3 + 5 = \u0038")
+ assert(unescapeSQLString(""""\000"""") == "\u0000")
+
+ // String including invalid '\000' style literal characters.
+ assert(unescapeSQLString(""""\256"""") == "256")
+
+ // String including a '\u0000' style literal character (\u732B is a cat in Kanji).
+ assert(unescapeSQLString(""""How cute \u732B are"""") == "How cute \u732B are")
+
+ // String including a surrogate pair character
+ // (\uD867\uDE3D is Okhotsk atka mackerel in Kanji).
+ assert(unescapeSQLString(""""\uD867\uDE3D is a fish"""") == "\uD867\uDE3D is a fish")
+
+ // scalastyle:on nonascii
+ }
+
+ // TODO: Add test cases for other methods in ParserUtils
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala
new file mode 100644
index 0000000000..411e2372f2
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala
@@ -0,0 +1,431 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.catalyst.parser
+
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.catalyst.analysis.UnresolvedGenerator
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans._
+import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.types.{BooleanType, IntegerType}
+
+class PlanParserSuite extends PlanTest {
+ import CatalystSqlParser._
+ import org.apache.spark.sql.catalyst.dsl.expressions._
+ import org.apache.spark.sql.catalyst.dsl.plans._
+
+ def assertEqual(sqlCommand: String, plan: LogicalPlan): Unit = {
+ comparePlans(parsePlan(sqlCommand), plan)
+ }
+
+ def intercept(sqlCommand: String, messages: String*): Unit = {
+ val e = intercept[ParseException](parsePlan(sqlCommand))
+ messages.foreach { message =>
+ assert(e.message.contains(message))
+ }
+ }
+
+ test("case insensitive") {
+ val plan = table("a").select(star())
+ assertEqual("sELEct * FroM a", plan)
+ assertEqual("select * fRoM a", plan)
+ assertEqual("SELECT * FROM a", plan)
+ }
+
+ test("show functions") {
+ assertEqual("show functions", ShowFunctions(None, None))
+ assertEqual("show functions foo", ShowFunctions(None, Some("foo")))
+ assertEqual("show functions foo.bar", ShowFunctions(Some("foo"), Some("bar")))
+ assertEqual("show functions 'foo\\\\.*'", ShowFunctions(None, Some("foo\\.*")))
+ intercept("show functions foo.bar.baz", "SHOW FUNCTIONS unsupported name")
+ }
+
+ test("describe function") {
+ assertEqual("describe function bar", DescribeFunction("bar", isExtended = false))
+ assertEqual("describe function extended bar", DescribeFunction("bar", isExtended = true))
+ assertEqual("describe function foo.bar", DescribeFunction("foo.bar", isExtended = false))
+ assertEqual("describe function extended f.bar", DescribeFunction("f.bar", isExtended = true))
+ }
+
+ test("set operations") {
+ val a = table("a").select(star())
+ val b = table("b").select(star())
+
+ assertEqual("select * from a union select * from b", Distinct(a.union(b)))
+ assertEqual("select * from a union distinct select * from b", Distinct(a.union(b)))
+ assertEqual("select * from a union all select * from b", a.union(b))
+ assertEqual("select * from a except select * from b", a.except(b))
+ intercept("select * from a except all select * from b", "EXCEPT ALL is not supported.")
+ assertEqual("select * from a except distinct select * from b", a.except(b))
+ assertEqual("select * from a intersect select * from b", a.intersect(b))
+ intercept("select * from a intersect all select * from b", "INTERSECT ALL is not supported.")
+ assertEqual("select * from a intersect distinct select * from b", a.intersect(b))
+ }
+
+ test("common table expressions") {
+ def cte(plan: LogicalPlan, namedPlans: (String, LogicalPlan)*): With = {
+ val ctes = namedPlans.map {
+ case (name, cte) =>
+ name -> SubqueryAlias(name, cte)
+ }.toMap
+ With(plan, ctes)
+ }
+ assertEqual(
+ "with cte1 as (select * from a) select * from cte1",
+ cte(table("cte1").select(star()), "cte1" -> table("a").select(star())))
+ assertEqual(
+ "with cte1 (select 1) select * from cte1",
+ cte(table("cte1").select(star()), "cte1" -> OneRowRelation.select(1)))
+ assertEqual(
+ "with cte1 (select 1), cte2 as (select * from cte1) select * from cte2",
+ cte(table("cte2").select(star()),
+ "cte1" -> OneRowRelation.select(1),
+ "cte2" -> table("cte1").select(star())))
+ intercept(
+ "with cte1 (select 1), cte1 as (select 1 from cte1) select * from cte1",
+ "Name 'cte1' is used for multiple common table expressions")
+ }
+
+ test("simple select query") {
+ assertEqual("select 1", OneRowRelation.select(1))
+ assertEqual("select a, b", OneRowRelation.select('a, 'b))
+ assertEqual("select a, b from db.c", table("db", "c").select('a, 'b))
+ assertEqual("select a, b from db.c where x < 1", table("db", "c").where('x < 1).select('a, 'b))
+ assertEqual(
+ "select a, b from db.c having x < 1",
+ table("db", "c").select('a, 'b).where(('x < 1).cast(BooleanType)))
+ assertEqual("select distinct a, b from db.c", Distinct(table("db", "c").select('a, 'b)))
+ assertEqual("select all a, b from db.c", table("db", "c").select('a, 'b))
+ }
+
+ test("reverse select query") {
+ assertEqual("from a", table("a"))
+ assertEqual("from a select b, c", table("a").select('b, 'c))
+ assertEqual(
+ "from db.a select b, c where d < 1", table("db", "a").where('d < 1).select('b, 'c))
+ assertEqual("from a select distinct b, c", Distinct(table("a").select('b, 'c)))
+ assertEqual(
+ "from (from a union all from b) c select *",
+ table("a").union(table("b")).as("c").select(star()))
+ }
+
+ test("multi select query") {
+ assertEqual(
+ "from a select * select * where s < 10",
+ table("a").select(star()).union(table("a").where('s < 10).select(star())))
+ intercept(
+ "from a select * select * from x where a.s < 10",
+ "Multi-Insert queries cannot have a FROM clause in their individual SELECT statements")
+ assertEqual(
+ "from a insert into tbl1 select * insert into tbl2 select * where s < 10",
+ table("a").select(star()).insertInto("tbl1").union(
+ table("a").where('s < 10).select(star()).insertInto("tbl2")))
+ }
+
+ test("query organization") {
+ // Test all valid combinations of order by/sort by/distribute by/cluster by/limit/windows
+ val baseSql = "select * from t"
+ val basePlan = table("t").select(star())
+
+ val ws = Map("w1" -> WindowSpecDefinition(Seq.empty, Seq.empty, UnspecifiedFrame))
+ val limitWindowClauses = Seq(
+ ("", (p: LogicalPlan) => p),
+ (" limit 10", (p: LogicalPlan) => p.limit(10)),
+ (" window w1 as ()", (p: LogicalPlan) => WithWindowDefinition(ws, p)),
+ (" window w1 as () limit 10", (p: LogicalPlan) => WithWindowDefinition(ws, p).limit(10))
+ )
+
+ val orderSortDistrClusterClauses = Seq(
+ ("", basePlan),
+ (" order by a, b desc", basePlan.orderBy('a.asc, 'b.desc)),
+ (" sort by a, b desc", basePlan.sortBy('a.asc, 'b.desc)),
+ (" distribute by a, b", basePlan.distribute('a, 'b)),
+ (" distribute by a sort by b", basePlan.distribute('a).sortBy('b.asc)),
+ (" cluster by a, b", basePlan.distribute('a, 'b).sortBy('a.asc, 'b.asc))
+ )
+
+ orderSortDistrClusterClauses.foreach {
+ case (s1, p1) =>
+ limitWindowClauses.foreach {
+ case (s2, pf2) =>
+ assertEqual(baseSql + s1 + s2, pf2(p1))
+ }
+ }
+
+ val msg = "Combination of ORDER BY/SORT BY/DISTRIBUTE BY/CLUSTER BY is not supported"
+ intercept(s"$baseSql order by a sort by a", msg)
+ intercept(s"$baseSql cluster by a distribute by a", msg)
+ intercept(s"$baseSql order by a cluster by a", msg)
+ intercept(s"$baseSql order by a distribute by a", msg)
+ }
+
+ test("insert into") {
+ val sql = "select * from t"
+ val plan = table("t").select(star())
+ def insert(
+ partition: Map[String, Option[String]],
+ overwrite: Boolean = false,
+ ifNotExists: Boolean = false): LogicalPlan =
+ InsertIntoTable(table("s"), partition, plan, overwrite, ifNotExists)
+
+ // Single inserts
+ assertEqual(s"insert overwrite table s $sql",
+ insert(Map.empty, overwrite = true))
+ assertEqual(s"insert overwrite table s if not exists $sql",
+ insert(Map.empty, overwrite = true, ifNotExists = true))
+ assertEqual(s"insert into s $sql",
+ insert(Map.empty))
+ assertEqual(s"insert into table s partition (c = 'd', e = 1) $sql",
+ insert(Map("c" -> Option("d"), "e" -> Option("1"))))
+ assertEqual(s"insert overwrite table s partition (c = 'd', x) if not exists $sql",
+ insert(Map("c" -> Option("d"), "x" -> None), overwrite = true, ifNotExists = true))
+
+ // Multi insert
+ val plan2 = table("t").where('x > 5).select(star())
+ assertEqual("from t insert into s select * limit 1 insert into u select * where x > 5",
+ InsertIntoTable(
+ table("s"), Map.empty, plan.limit(1), overwrite = false, ifNotExists = false).union(
+ InsertIntoTable(
+ table("u"), Map.empty, plan2, overwrite = false, ifNotExists = false)))
+ }
+
+ test("aggregation") {
+ val sql = "select a, b, sum(c) as c from d group by a, b"
+
+ // Normal
+ assertEqual(sql, table("d").groupBy('a, 'b)('a, 'b, 'sum.function('c).as("c")))
+
+ // Cube
+ assertEqual(s"$sql with cube",
+ table("d").groupBy(Cube(Seq('a, 'b)))('a, 'b, 'sum.function('c).as("c")))
+
+ // Rollup
+ assertEqual(s"$sql with rollup",
+ table("d").groupBy(Rollup(Seq('a, 'b)))('a, 'b, 'sum.function('c).as("c")))
+
+ // Grouping Sets
+ assertEqual(s"$sql grouping sets((a, b), (a), ())",
+ GroupingSets(Seq(0, 1, 3), Seq('a, 'b), table("d"), Seq('a, 'b, 'sum.function('c).as("c"))))
+ intercept(s"$sql grouping sets((a, b), (c), ())",
+ "c doesn't show up in the GROUP BY list")
+ }
+
+ test("limit") {
+ val sql = "select * from t"
+ val plan = table("t").select(star())
+ assertEqual(s"$sql limit 10", plan.limit(10))
+ assertEqual(s"$sql limit cast(9 / 4 as int)", plan.limit(Cast(Literal(9) / 4, IntegerType)))
+ }
+
+ test("window spec") {
+ // Note that WindowSpecs are tested in the ExpressionParserSuite.
+ val sql = "select * from t"
+ val plan = table("t").select(star())
+ val spec = WindowSpecDefinition(Seq('a, 'b), Seq('c.asc),
+ SpecifiedWindowFrame(RowFrame, ValuePreceding(1), ValueFollowing(1)))
+
+ // Test window resolution.
+ val ws1 = Map("w1" -> spec, "w2" -> spec, "w3" -> spec)
+ assertEqual(
+ s"""$sql
+ |window w1 as (partition by a, b order by c rows between 1 preceding and 1 following),
+ | w2 as w1,
+ | w3 as w1""".stripMargin,
+ WithWindowDefinition(ws1, plan))
+
+ // Fail with no reference.
+ intercept(s"$sql window w2 as w1", "Cannot resolve window reference 'w1'")
+
+ // Fail when resolved reference is not a window spec.
+ intercept(
+ s"""$sql
+ |window w1 as (partition by a, b order by c rows between 1 preceding and 1 following),
+ | w2 as w1,
+ | w3 as w2""".stripMargin,
+ "Window reference 'w2' is not a window specification"
+ )
+ }
+
+ test("lateral view") {
+ // Single lateral view
+ assertEqual(
+ "select * from t lateral view explode(x) expl as x",
+ table("t")
+ .generate(Explode('x), join = true, outer = false, Some("expl"), Seq("x"))
+ .select(star()))
+
+ // Multiple lateral views
+ assertEqual(
+ """select *
+ |from t
+ |lateral view explode(x) expl
+ |lateral view outer json_tuple(x, y) jtup q, z""".stripMargin,
+ table("t")
+ .generate(Explode('x), join = true, outer = false, Some("expl"), Seq.empty)
+ .generate(JsonTuple(Seq('x, 'y)), join = true, outer = true, Some("jtup"), Seq("q", "z"))
+ .select(star()))
+
+ // Multi-Insert lateral views.
+ val from = table("t1").generate(Explode('x), join = true, outer = false, Some("expl"), Seq("x"))
+ assertEqual(
+ """from t1
+ |lateral view explode(x) expl as x
+ |insert into t2
+ |select *
+ |lateral view json_tuple(x, y) jtup q, z
+ |insert into t3
+ |select *
+ |where s < 10
+ """.stripMargin,
+ Union(from
+ .generate(JsonTuple(Seq('x, 'y)), join = true, outer = false, Some("jtup"), Seq("q", "z"))
+ .select(star())
+ .insertInto("t2"),
+ from.where('s < 10).select(star()).insertInto("t3")))
+
+ // Unresolved generator.
+ val expected = table("t")
+ .generate(
+ UnresolvedGenerator("posexplode", Seq('x)),
+ join = true,
+ outer = false,
+ Some("posexpl"),
+ Seq("x", "y"))
+ .select(star())
+ assertEqual(
+ "select * from t lateral view posexplode(x) posexpl as x, y",
+ expected)
+ }
+
+ test("joins") {
+ // Test single joins.
+ val testUnconditionalJoin = (sql: String, jt: JoinType) => {
+ assertEqual(
+ s"select * from t as tt $sql u",
+ table("t").as("tt").join(table("u"), jt, None).select(star()))
+ }
+ val testConditionalJoin = (sql: String, jt: JoinType) => {
+ assertEqual(
+ s"select * from t $sql u as uu on a = b",
+ table("t").join(table("u").as("uu"), jt, Option('a === 'b)).select(star()))
+ }
+ val testNaturalJoin = (sql: String, jt: JoinType) => {
+ assertEqual(
+ s"select * from t tt natural $sql u as uu",
+ table("t").as("tt").join(table("u").as("uu"), NaturalJoin(jt), None).select(star()))
+ }
+ val testUsingJoin = (sql: String, jt: JoinType) => {
+ assertEqual(
+ s"select * from t $sql u using(a, b)",
+ table("t").join(table("u"), UsingJoin(jt, Seq('a.attr, 'b.attr)), None).select(star()))
+ }
+ val testAll = Seq(testUnconditionalJoin, testConditionalJoin, testNaturalJoin, testUsingJoin)
+ val testExistence = Seq(testUnconditionalJoin, testConditionalJoin, testUsingJoin)
+ def test(sql: String, jt: JoinType, tests: Seq[(String, JoinType) => Unit]): Unit = {
+ tests.foreach(_(sql, jt))
+ }
+ test("cross join", Inner, Seq(testUnconditionalJoin))
+ test(",", Inner, Seq(testUnconditionalJoin))
+ test("join", Inner, testAll)
+ test("inner join", Inner, testAll)
+ test("left join", LeftOuter, testAll)
+ test("left outer join", LeftOuter, testAll)
+ test("right join", RightOuter, testAll)
+ test("right outer join", RightOuter, testAll)
+ test("full join", FullOuter, testAll)
+ test("full outer join", FullOuter, testAll)
+ test("left semi join", LeftSemi, testExistence)
+ test("left anti join", LeftAnti, testExistence)
+ test("anti join", LeftAnti, testExistence)
+
+ // Test multiple consecutive joins
+ assertEqual(
+ "select * from a join b join c right join d",
+ table("a").join(table("b")).join(table("c")).join(table("d"), RightOuter).select(star()))
+ }
+
+ test("sampled relations") {
+ val sql = "select * from t"
+ assertEqual(s"$sql tablesample(100 rows)",
+ table("t").limit(100).select(star()))
+ assertEqual(s"$sql tablesample(43 percent) as x",
+ Sample(0, .43d, withReplacement = false, 10L, table("t").as("x"))(true).select(star()))
+ assertEqual(s"$sql tablesample(bucket 4 out of 10) as x",
+ Sample(0, .4d, withReplacement = false, 10L, table("t").as("x"))(true).select(star()))
+ intercept(s"$sql tablesample(bucket 4 out of 10 on x) as x",
+ "TABLESAMPLE(BUCKET x OUT OF y ON id) is not supported")
+ intercept(s"$sql tablesample(bucket 11 out of 10) as x",
+ s"Sampling fraction (${11.0/10.0}) must be on interval [0, 1]")
+ }
+
+ test("sub-query") {
+ val plan = table("t0").select('id)
+ assertEqual("select id from (t0)", plan)
+ assertEqual("select id from ((((((t0))))))", plan)
+ assertEqual(
+ "(select * from t1) union distinct (select * from t2)",
+ Distinct(table("t1").select(star()).union(table("t2").select(star()))))
+ assertEqual(
+ "select * from ((select * from t1) union (select * from t2)) t",
+ Distinct(
+ table("t1").select(star()).union(table("t2").select(star()))).as("t").select(star()))
+ assertEqual(
+ """select id
+ |from (((select id from t0)
+ | union all
+ | (select id from t0))
+ | union all
+ | (select id from t0)) as u_1
+ """.stripMargin,
+ plan.union(plan).union(plan).as("u_1").select('id))
+ }
+
+ test("scalar sub-query") {
+ assertEqual(
+ "select (select max(b) from s) ss from t",
+ table("t").select(ScalarSubquery(table("s").select('max.function('b))).as("ss")))
+ assertEqual(
+ "select * from t where a = (select b from s)",
+ table("t").where('a === ScalarSubquery(table("s").select('b))).select(star()))
+ assertEqual(
+ "select g from t group by g having a > (select b from s)",
+ table("t")
+ .groupBy('g)('g)
+ .where(('a > ScalarSubquery(table("s").select('b))).cast(BooleanType)))
+ }
+
+ test("table reference") {
+ assertEqual("table t", table("t"))
+ assertEqual("table d.t", table("d", "t"))
+ }
+
+ test("inline table") {
+ assertEqual("values 1, 2, 3, 4", LocalRelation.fromExternalRows(
+ Seq('col1.int),
+ Seq(1, 2, 3, 4).map(x => Row(x))))
+ assertEqual(
+ "values (1, 'a'), (2, 'b'), (3, 'c') as tbl(a, b)",
+ LocalRelation.fromExternalRows(
+ Seq('a.int, 'b.string),
+ Seq((1, "a"), (2, "b"), (3, "c")).map(x => Row(x._1, x._2))).as("tbl"))
+ intercept("values (a, 'a'), (b, 'b')",
+ "All expressions in an inline table must be constants.")
+ intercept("values (1, 'a'), (2, 'b') as tbl(a, b, c)",
+ "Number of aliases must match the number of fields in an inline table.")
+ intercept[ArrayIndexOutOfBoundsException](parsePlan("values (1, 'a'), (2, 'b', 5Y)"))
+ }
+}
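The assertions above boil down to parsing a SQL string into an unresolved logical plan and comparing it against a plan built with the Catalyst DSL. A minimal illustrative sketch (not part of the patch; it only uses helpers the suite itself relies on):

  import org.apache.spark.sql.catalyst.dsl.expressions._
  import org.apache.spark.sql.catalyst.dsl.plans._
  import org.apache.spark.sql.catalyst.parser.CatalystSqlParser.parsePlan

  // Parse the SQL text into an unresolved logical plan.
  val parsed = parsePlan("select * from a union all select * from b")
  // Build the equivalent plan with the DSL used throughout the suite.
  val expected = table("a").select(star()).union(table("b").select(star()))
  // assertEqual in the suite compares such plans after normalization
  // (see the PlanTest changes later in this diff).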
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableIdentifierParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableIdentifierParserSuite.scala
new file mode 100644
index 0000000000..297b1931a9
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableIdentifierParserSuite.scala
@@ -0,0 +1,42 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.catalyst.parser
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.TableIdentifier
+
+class TableIdentifierParserSuite extends SparkFunSuite {
+ import CatalystSqlParser._
+
+ test("table identifier") {
+ // Regular names.
+ assert(TableIdentifier("q") === parseTableIdentifier("q"))
+ assert(TableIdentifier("q", Option("d")) === parseTableIdentifier("d.q"))
+
+ // Illegal names.
+ intercept[ParseException](parseTableIdentifier(""))
+ intercept[ParseException](parseTableIdentifier("d.q.g"))
+
+ // SQL Keywords.
+ val keywords = Seq("select", "from", "where", "left", "right")
+ keywords.foreach { keyword =>
+ intercept[ParseException](parseTableIdentifier(keyword))
+ assert(TableIdentifier(keyword) === parseTableIdentifier(s"`$keyword`"))
+ assert(TableIdentifier(keyword, Option("db")) === parseTableIdentifier(s"db.`$keyword`"))
+ }
+ }
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala
index e5063599a3..81cc6b123c 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala
@@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical._
-import org.apache.spark.sql.types.{IntegerType, StringType}
+import org.apache.spark.sql.types.{DoubleType, IntegerType, LongType, StringType}
class ConstraintPropagationSuite extends SparkFunSuite {
@@ -88,6 +88,33 @@ class ConstraintPropagationSuite extends SparkFunSuite {
IsNotNull(resolveColumn(aliasedRelation.analyze, "a")))))
}
+ test("propagating constraints in expand") {
+ val tr = LocalRelation('a.int, 'b.int, 'c.int)
+
+ assert(tr.analyze.constraints.isEmpty)
+
+ // We add IsNotNull constraints for 'a, 'b and 'c into LocalRelation
+ // by creating notNullRelation.
+ val notNullRelation = tr.where('c.attr > 10 && 'a.attr < 5 && 'b.attr > 2)
+ verifyConstraints(notNullRelation.analyze.constraints,
+ ExpressionSet(Seq(resolveColumn(notNullRelation.analyze, "c") > 10,
+ IsNotNull(resolveColumn(notNullRelation.analyze, "c")),
+ resolveColumn(notNullRelation.analyze, "a") < 5,
+ IsNotNull(resolveColumn(notNullRelation.analyze, "a")),
+ resolveColumn(notNullRelation.analyze, "b") > 2,
+ IsNotNull(resolveColumn(notNullRelation.analyze, "b")))))
+
+ val expand = Expand(
+ Seq(
+ Seq('c, Literal.create(null, StringType), 1),
+ Seq('c, 'a, 2)),
+ Seq('c, 'a, 'gid.int),
+ Project(Seq('a, 'c),
+ notNullRelation))
+ verifyConstraints(expand.analyze.constraints,
+ ExpressionSet(Seq.empty[Expression]))
+ }
+
test("propagating constraints in aliases") {
val tr = LocalRelation('a.int, 'b.string, 'c.int)
@@ -121,6 +148,20 @@ class ConstraintPropagationSuite extends SparkFunSuite {
.analyze.constraints,
ExpressionSet(Seq(resolveColumn(tr1, "a") > 10,
IsNotNull(resolveColumn(tr1, "a")))))
+
+ val a = resolveColumn(tr1, "a")
+ verifyConstraints(tr1
+ .where('a.attr > 10)
+ .union(tr2.where('d.attr > 11))
+ .analyze.constraints,
+ ExpressionSet(Seq(a > 10 || a > 11, IsNotNull(a))))
+
+ val b = resolveColumn(tr1, "b")
+ verifyConstraints(tr1
+ .where('a.attr > 10 && 'b.attr < 10)
+ .union(tr2.where('d.attr > 11 && 'e.attr < 11))
+ .analyze.constraints,
+ ExpressionSet(Seq(a > 10 || a > 11, b < 10 || b < 11, IsNotNull(a), IsNotNull(b))))
}
test("propagating constraints in intersect") {
@@ -219,6 +260,89 @@ class ConstraintPropagationSuite extends SparkFunSuite {
IsNotNull(resolveColumn(tr, "b")))))
}
+ test("infer constraints on cast") {
+ val tr = LocalRelation('a.int, 'b.long, 'c.int, 'd.long, 'e.int)
+ verifyConstraints(
+ tr.where('a.attr === 'b.attr &&
+ 'c.attr + 100 > 'd.attr &&
+ IsNotNull(Cast(Cast(resolveColumn(tr, "e"), LongType), LongType))).analyze.constraints,
+ ExpressionSet(Seq(Cast(resolveColumn(tr, "a"), LongType) === resolveColumn(tr, "b"),
+ Cast(resolveColumn(tr, "c") + 100, LongType) > resolveColumn(tr, "d"),
+ IsNotNull(resolveColumn(tr, "a")),
+ IsNotNull(resolveColumn(tr, "b")),
+ IsNotNull(resolveColumn(tr, "c")),
+ IsNotNull(resolveColumn(tr, "d")),
+ IsNotNull(resolveColumn(tr, "e")),
+ IsNotNull(Cast(Cast(resolveColumn(tr, "e"), LongType), LongType)))))
+ }
+
+ test("infer isnotnull constraints from compound expressions") {
+ val tr = LocalRelation('a.int, 'b.long, 'c.int, 'd.long, 'e.int)
+ verifyConstraints(
+ tr.where('a.attr + 'b.attr === 'c.attr &&
+ IsNotNull(
+ Cast(
+ Cast(Cast(resolveColumn(tr, "e"), LongType), LongType), LongType))).analyze.constraints,
+ ExpressionSet(Seq(
+ Cast(resolveColumn(tr, "a"), LongType) + resolveColumn(tr, "b") ===
+ Cast(resolveColumn(tr, "c"), LongType),
+ IsNotNull(resolveColumn(tr, "a")),
+ IsNotNull(resolveColumn(tr, "b")),
+ IsNotNull(resolveColumn(tr, "c")),
+ IsNotNull(resolveColumn(tr, "e")),
+ IsNotNull(Cast(Cast(Cast(resolveColumn(tr, "e"), LongType), LongType), LongType)))))
+
+ verifyConstraints(
+ tr.where(('a.attr * 'b.attr + 100) === 'c.attr && 'd / 10 === 'e).analyze.constraints,
+ ExpressionSet(Seq(
+ Cast(resolveColumn(tr, "a"), LongType) * resolveColumn(tr, "b") + Cast(100, LongType) ===
+ Cast(resolveColumn(tr, "c"), LongType),
+ Cast(resolveColumn(tr, "d"), DoubleType) /
+ Cast(Cast(10, LongType), DoubleType) ===
+ Cast(resolveColumn(tr, "e"), DoubleType),
+ IsNotNull(resolveColumn(tr, "a")),
+ IsNotNull(resolveColumn(tr, "b")),
+ IsNotNull(resolveColumn(tr, "c")),
+ IsNotNull(resolveColumn(tr, "d")),
+ IsNotNull(resolveColumn(tr, "e")))))
+
+ verifyConstraints(
+ tr.where(('a.attr * 'b.attr - 10) >= 'c.attr && 'd / 10 < 'e).analyze.constraints,
+ ExpressionSet(Seq(
+ Cast(resolveColumn(tr, "a"), LongType) * resolveColumn(tr, "b") - Cast(10, LongType) >=
+ Cast(resolveColumn(tr, "c"), LongType),
+ Cast(resolveColumn(tr, "d"), DoubleType) /
+ Cast(Cast(10, LongType), DoubleType) <
+ Cast(resolveColumn(tr, "e"), DoubleType),
+ IsNotNull(resolveColumn(tr, "a")),
+ IsNotNull(resolveColumn(tr, "b")),
+ IsNotNull(resolveColumn(tr, "c")),
+ IsNotNull(resolveColumn(tr, "d")),
+ IsNotNull(resolveColumn(tr, "e")))))
+
+ verifyConstraints(
+ tr.where('a.attr + 'b.attr - 'c.attr * 'd.attr > 'e.attr * 1000).analyze.constraints,
+ ExpressionSet(Seq(
+ (Cast(resolveColumn(tr, "a"), LongType) + resolveColumn(tr, "b")) -
+ (Cast(resolveColumn(tr, "c"), LongType) * resolveColumn(tr, "d")) >
+ Cast(resolveColumn(tr, "e") * 1000, LongType),
+ IsNotNull(resolveColumn(tr, "a")),
+ IsNotNull(resolveColumn(tr, "b")),
+ IsNotNull(resolveColumn(tr, "c")),
+ IsNotNull(resolveColumn(tr, "d")),
+ IsNotNull(resolveColumn(tr, "e")))))
+
+ // The constraint IsNotNull(IsNotNull(expr)) doesn't guarantee expr is not null.
+ verifyConstraints(
+ tr.where('a.attr === 'c.attr &&
+ IsNotNull(IsNotNull(resolveColumn(tr, "b")))).analyze.constraints,
+ ExpressionSet(Seq(
+ resolveColumn(tr, "a") === resolveColumn(tr, "c"),
+ IsNotNull(IsNotNull(resolveColumn(tr, "b"))),
+ IsNotNull(resolveColumn(tr, "a")),
+ IsNotNull(resolveColumn(tr, "c")))))
+ }
+
test("infer IsNotNull constraints from non-nullable attributes") {
val tr = LocalRelation('a.int, AttributeReference("b", IntegerType, nullable = false)(),
AttributeReference("c", StringType, nullable = false)())
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala
index 0541844e0b..7191936699 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala
@@ -19,7 +19,8 @@ package org.apache.spark.sql.catalyst.plans
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, OneRowRelation}
+import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
+import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, OneRowRelation, Sample}
import org.apache.spark.sql.catalyst.util._
/**
@@ -32,29 +33,37 @@ abstract class PlanTest extends SparkFunSuite with PredicateHelper {
*/
protected def normalizeExprIds(plan: LogicalPlan) = {
plan transformAllExpressions {
+ case s: ScalarSubquery =>
+ ScalarSubquery(s.query, ExprId(0))
case a: AttributeReference =>
AttributeReference(a.name, a.dataType, a.nullable)(exprId = ExprId(0))
case a: Alias =>
Alias(a.child, a.name)(exprId = ExprId(0))
+ case ae: AggregateExpression =>
+ ae.copy(resultId = ExprId(0))
}
}
/**
- * Normalizes the filter conditions that appear in the plan. For instance,
- * ((expr 1 && expr 2) && expr 3), (expr 1 && expr 2 && expr 3), (expr 3 && (expr 1 && expr 2)
- * etc., will all now be equivalent.
+ * Normalizes plans:
+ * - Filter: normalizes the filter conditions that appear in a plan. For instance,
+ * ((expr 1 && expr 2) && expr 3), (expr 1 && expr 2 && expr 3), (expr 3 && (expr 1 && expr 2))
+ * etc., will all now be equivalent.
+ * - Sample: the seed is replaced by 0L.
*/
- private def normalizeFilters(plan: LogicalPlan) = {
+ private def normalizePlan(plan: LogicalPlan): LogicalPlan = {
plan transform {
case filter @ Filter(condition: Expression, child: LogicalPlan) =>
Filter(splitConjunctivePredicates(condition).sortBy(_.hashCode()).reduce(And), child)
+ case sample: Sample =>
+ sample.copy(seed = 0L)(true)
}
}
/** Fails the test if the two plans do not match */
protected def comparePlans(plan1: LogicalPlan, plan2: LogicalPlan) {
- val normalized1 = normalizeFilters(normalizeExprIds(plan1))
- val normalized2 = normalizeFilters(normalizeExprIds(plan2))
+ val normalized1 = normalizePlan(normalizeExprIds(plan1))
+ val normalized2 = normalizePlan(normalizeExprIds(plan2))
if (normalized1 != normalized2) {
fail(
s"""
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/SameResultSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/SameResultSuite.scala
index 37941cf34e..467f76193c 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/SameResultSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/SameResultSuite.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.plans
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
-import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
+import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Union}
import org.apache.spark.sql.catalyst.util._
/**
@@ -61,4 +61,9 @@ class SameResultSuite extends SparkFunSuite {
test("sorts") {
assertSameResult(testRelation.orderBy('a.asc), testRelation2.orderBy('a.asc))
}
+
+ test("union") {
+ assertSameResult(Union(Seq(testRelation, testRelation2)),
+ Union(Seq(testRelation2, testRelation)))
+ }
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/StringUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/StringUtilsSuite.scala
index d6f273f9e5..2ffc18a8d1 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/StringUtilsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/StringUtilsSuite.scala
@@ -31,4 +31,16 @@ class StringUtilsSuite extends SparkFunSuite {
assert(escapeLikeRegex("**") === "(?s)\\Q*\\E\\Q*\\E")
assert(escapeLikeRegex("a_b") === "(?s)\\Qa\\E.\\Qb\\E")
}
+
+ test("filter pattern") {
+ val names = Seq("a1", "a2", "b2", "c3")
+ assert(filterPattern(names, " * ") === Seq("a1", "a2", "b2", "c3"))
+ assert(filterPattern(names, "*a*") === Seq("a1", "a2"))
+ assert(filterPattern(names, " *a* ") === Seq("a1", "a2"))
+ assert(filterPattern(names, " a* ") === Seq("a1", "a2"))
+ assert(filterPattern(names, " a.* ") === Seq("a1", "a2"))
+ assert(filterPattern(names, " B.*|a* ") === Seq("a1", "a2", "b2"))
+ assert(filterPattern(names, " a. ") === Seq("a1", "a2"))
+ assert(filterPattern(names, " d* ") === Nil)
+ }
}
diff --git a/sql/core/pom.xml b/sql/core/pom.xml
index f347a9929c..e1071ebfb5 100644
--- a/sql/core/pom.xml
+++ b/sql/core/pom.xml
@@ -39,7 +39,7 @@
<dependency>
<groupId>com.univocity</groupId>
<artifactId>univocity-parsers</artifactId>
- <version>1.5.6</version>
+ <version>2.0.2</version>
<type>jar</type>
</dependency>
<dependency>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/BufferedRowIterator.java b/sql/core/src/main/java/org/apache/spark/sql/execution/BufferedRowIterator.java
index dbea8521be..086547c793 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/BufferedRowIterator.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/BufferedRowIterator.java
@@ -36,6 +36,8 @@ public abstract class BufferedRowIterator {
protected UnsafeRow unsafeRow = new UnsafeRow(0);
private long startTimeNs = System.nanoTime();
+ protected int partitionIndex = -1;
+
public boolean hasNext() throws IOException {
if (currentRows.isEmpty()) {
processNext();
@@ -58,7 +60,7 @@ public abstract class BufferedRowIterator {
/**
* Initializes from array of iterators of InternalRow.
*/
- public abstract void init(Iterator<InternalRow> iters[]);
+ public abstract void init(int index, Iterator<InternalRow>[] iters);
/**
* Append a row to currentRows.
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
index 8882903bbf..1f1b5389aa 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
@@ -134,7 +134,7 @@ public final class UnsafeFixedWidthAggregationMap {
if (!loc.isDefined()) {
// This is the first time that we've seen this grouping key, so we'll insert a copy of the
// empty aggregation buffer into the map:
- boolean putSucceeded = loc.putNewKey(
+ boolean putSucceeded = loc.append(
key.getBaseObject(),
key.getBaseOffset(),
key.getSizeInBytes(),
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java
index d3bfb00b3f..8132bba04c 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java
@@ -272,5 +272,5 @@ public final class UnsafeKVExternalSorter {
public void close() {
cleanupResources();
}
- };
+ }
}
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java
index 6cc2fda587..ea37a08ab5 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java
@@ -27,6 +27,7 @@ import org.apache.parquet.column.Encoding;
import org.apache.parquet.column.page.*;
import org.apache.parquet.column.values.ValuesReader;
import org.apache.parquet.io.api.Binary;
+import org.apache.parquet.schema.PrimitiveType;
import org.apache.spark.sql.execution.vectorized.ColumnVector;
import org.apache.spark.sql.types.DataTypes;
@@ -115,57 +116,6 @@ public class VectorizedColumnReader {
}
/**
- * TODO: Hoist the useDictionary branch to decode*Batch and make the batch page aligned.
- */
- public boolean nextBoolean() {
- if (!useDictionary) {
- return dataColumn.readBoolean();
- } else {
- return dictionary.decodeToBoolean(dataColumn.readValueDictionaryId());
- }
- }
-
- public int nextInt() {
- if (!useDictionary) {
- return dataColumn.readInteger();
- } else {
- return dictionary.decodeToInt(dataColumn.readValueDictionaryId());
- }
- }
-
- public long nextLong() {
- if (!useDictionary) {
- return dataColumn.readLong();
- } else {
- return dictionary.decodeToLong(dataColumn.readValueDictionaryId());
- }
- }
-
- public float nextFloat() {
- if (!useDictionary) {
- return dataColumn.readFloat();
- } else {
- return dictionary.decodeToFloat(dataColumn.readValueDictionaryId());
- }
- }
-
- public double nextDouble() {
- if (!useDictionary) {
- return dataColumn.readDouble();
- } else {
- return dictionary.decodeToDouble(dataColumn.readValueDictionaryId());
- }
- }
-
- public Binary nextBinary() {
- if (!useDictionary) {
- return dataColumn.readBytes();
- } else {
- return dictionary.decodeToBinary(dataColumn.readValueDictionaryId());
- }
- }
-
- /**
* Advances to the next value. Returns true if the value is non-null.
*/
private boolean next() throws IOException {
@@ -200,8 +150,26 @@ public class VectorizedColumnReader {
ColumnVector dictionaryIds = column.reserveDictionaryIds(total);
defColumn.readIntegers(
num, dictionaryIds, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn);
- decodeDictionaryIds(rowId, num, column, dictionaryIds);
+
+ if (column.hasDictionary() || (rowId == 0 &&
+ (descriptor.getType() == PrimitiveType.PrimitiveTypeName.INT32 ||
+ descriptor.getType() == PrimitiveType.PrimitiveTypeName.INT64 ||
+ descriptor.getType() == PrimitiveType.PrimitiveTypeName.FLOAT ||
+ descriptor.getType() == PrimitiveType.PrimitiveTypeName.DOUBLE ||
+ descriptor.getType() == PrimitiveType.PrimitiveTypeName.BINARY))) {
+ // Column vector supports lazy decoding of dictionary values so just set the dictionary.
+ // We can't do this if rowId != 0 AND the column doesn't have a dictionary (i.e. some
+ // non-dictionary encoded values have already been added).
+ column.setDictionary(dictionary);
+ } else {
+ decodeDictionaryIds(rowId, num, column, dictionaryIds);
+ }
} else {
+ if (column.hasDictionary() && rowId != 0) {
+ // This batch already has dictionary encoded values, but this new page is not. The batch
+ // does not support mixing dictionary-encoded and plain values, so decode the dictionary here.
+ decodeDictionaryIds(0, rowId, column, column.getDictionaryIds());
+ }
column.setDictionary(null);
switch (descriptor.getType()) {
case BOOLEAN:
@@ -246,11 +214,45 @@ public class VectorizedColumnReader {
ColumnVector dictionaryIds) {
switch (descriptor.getType()) {
case INT32:
+ if (column.dataType() == DataTypes.IntegerType ||
+ DecimalType.is32BitDecimalType(column.dataType())) {
+ for (int i = rowId; i < rowId + num; ++i) {
+ column.putInt(i, dictionary.decodeToInt(dictionaryIds.getInt(i)));
+ }
+ } else if (column.dataType() == DataTypes.ByteType) {
+ for (int i = rowId; i < rowId + num; ++i) {
+ column.putByte(i, (byte) dictionary.decodeToInt(dictionaryIds.getInt(i)));
+ }
+ } else if (column.dataType() == DataTypes.ShortType) {
+ for (int i = rowId; i < rowId + num; ++i) {
+ column.putShort(i, (short) dictionary.decodeToInt(dictionaryIds.getInt(i)));
+ }
+ } else {
+ throw new NotImplementedException("Unimplemented type: " + column.dataType());
+ }
+ break;
+
case INT64:
+ if (column.dataType() == DataTypes.LongType ||
+ DecimalType.is64BitDecimalType(column.dataType())) {
+ for (int i = rowId; i < rowId + num; ++i) {
+ column.putLong(i, dictionary.decodeToLong(dictionaryIds.getInt(i)));
+ }
+ } else {
+ throw new NotImplementedException("Unimplemented type: " + column.dataType());
+ }
+ break;
+
case FLOAT:
+ for (int i = rowId; i < rowId + num; ++i) {
+ column.putFloat(i, dictionary.decodeToFloat(dictionaryIds.getInt(i)));
+ }
+ break;
+
case DOUBLE:
- case BINARY:
- column.setDictionary(dictionary);
+ for (int i = rowId; i < rowId + num; ++i) {
+ column.putDouble(i, dictionary.decodeToDouble(dictionaryIds.getInt(i)));
+ }
break;
case INT96:
if (column.dataType() == DataTypes.TimestampType) {
@@ -263,6 +265,16 @@ public class VectorizedColumnReader {
throw new NotImplementedException();
}
break;
+ case BINARY:
+ // TODO: this is incredibly inefficient as it blows up the dictionary right here. We
+ // need to do this better. We should probably add the dictionary data to the ColumnVector
+ // and reuse it across batches. This should mean adding a ByteArray would just update
+ // the length and offset.
+ for (int i = rowId; i < rowId + num; ++i) {
+ Binary v = dictionary.decodeToBinary(dictionaryIds.getInt(i));
+ column.putByteArray(i, v.getBytes());
+ }
+ break;
case FIXED_LEN_BYTE_ARRAY:
// DecimalType written in the legacy mode
if (DecimalType.is32BitDecimalType(column.dataType())) {
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java
index 5bfde55c3b..51bdf0f0f2 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java
@@ -18,6 +18,7 @@
package org.apache.spark.sql.execution.datasources.parquet;
import java.io.IOException;
+import java.util.Arrays;
import java.util.List;
import org.apache.hadoop.mapreduce.InputSplit;
@@ -30,7 +31,8 @@ import org.apache.spark.memory.MemoryMode;
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.execution.vectorized.ColumnVectorUtils;
import org.apache.spark.sql.execution.vectorized.ColumnarBatch;
-import org.apache.spark.sql.types.*;
+import org.apache.spark.sql.types.StructField;
+import org.apache.spark.sql.types.StructType;
/**
* A specialized RecordReader that reads into InternalRows or ColumnarBatches directly using the
@@ -99,20 +101,6 @@ public class VectorizedParquetRecordReader extends SpecificParquetRecordReaderBa
private static final MemoryMode DEFAULT_MEMORY_MODE = MemoryMode.ON_HEAP;
/**
- * Tries to initialize the reader for this split. Returns true if this reader supports reading
- * this split and false otherwise.
- */
- public boolean tryInitialize(InputSplit inputSplit, TaskAttemptContext taskAttemptContext)
- throws IOException, InterruptedException {
- try {
- initialize(inputSplit, taskAttemptContext);
- return true;
- } catch (UnsupportedOperationException e) {
- return false;
- }
- }
-
- /**
* Implementation of RecordReader API.
*/
@Override
@@ -190,7 +178,7 @@ public class VectorizedParquetRecordReader extends SpecificParquetRecordReaderBa
}
}
- columnarBatch = ColumnarBatch.allocate(batchSchema);
+ columnarBatch = ColumnarBatch.allocate(batchSchema, memMode);
if (partitionColumns != null) {
int partitionIdx = sparkSchema.fields().length;
for (int i = 0; i < partitionColumns.fields().length; i++) {
@@ -221,7 +209,7 @@ public class VectorizedParquetRecordReader extends SpecificParquetRecordReaderBa
return columnarBatch;
}
- /**
+ /*
* Can be called before any rows are returned to enable returning columnar batches directly.
*/
public void enableReturningBatches() {
@@ -269,7 +257,8 @@ public class VectorizedParquetRecordReader extends SpecificParquetRecordReaderBa
} else {
if (requestedSchema.getColumns().get(i).getMaxDefinitionLevel() == 0) {
// Column is missing in data but the required data is non-nullable. This file is invalid.
- throw new IOException("Required column is missing in data file. Col: " + colPath);
+ throw new IOException("Required column is missing in data file. Col: " +
+ Arrays.toString(colPath));
}
missingColumns[i] = true;
}
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/AggregateHashMap.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/AggregateHashMap.java
new file mode 100644
index 0000000000..69ce54390f
--- /dev/null
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/AggregateHashMap.java
@@ -0,0 +1,110 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.vectorized;
+
+import java.util.Arrays;
+
+import com.google.common.annotations.VisibleForTesting;
+
+import org.apache.spark.memory.MemoryMode;
+import org.apache.spark.sql.types.StructType;
+
+import static org.apache.spark.sql.types.DataTypes.LongType;
+
+/**
+ * This is an illustrative implementation of an append-only single-key/single-value aggregate hash
+ * map that can act as a 'cache' for extremely fast key-value lookups while evaluating aggregates
+ * (and fall back to the `BytesToBytesMap` if a given key isn't found). This can be potentially
+ * 'codegened' in TungstenAggregate to speed up aggregates w/ key.
+ *
+ * It is backed by a power-of-2-sized array for index lookups and a columnar batch that stores the
+ * key-value pairs. The index lookups in the array rely on linear probing (with a small number of
+ * maximum tries) and use an inexpensive hash function which makes it really efficient for a
+ * majority of lookups. However, using linear probing and an inexpensive hash function also makes it
+ * less robust compared to the `BytesToBytesMap` (especially for a large number of keys or even
+ * for certain distributions of keys) and requires us to fall back on the latter for correctness.
+ */
+public class AggregateHashMap {
+
+ private ColumnarBatch batch;
+ private int[] buckets;
+ private int numBuckets;
+ private int numRows = 0;
+ private int maxSteps = 3;
+
+ private static int DEFAULT_CAPACITY = 1 << 16;
+ private static double DEFAULT_LOAD_FACTOR = 0.25;
+ private static int DEFAULT_MAX_STEPS = 3;
+
+ public AggregateHashMap(StructType schema, int capacity, double loadFactor, int maxSteps) {
+
+ // We currently only support a single key-value pair whose key and value are both longs
+ assert (schema.size() == 2 && schema.fields()[0].dataType() == LongType &&
+ schema.fields()[1].dataType() == LongType);
+
+ // capacity should be a power of 2
+ assert (capacity > 0 && ((capacity & (capacity - 1)) == 0));
+
+ this.maxSteps = maxSteps;
+ numBuckets = (int) (capacity / loadFactor);
+ batch = ColumnarBatch.allocate(schema, MemoryMode.ON_HEAP, capacity);
+ buckets = new int[numBuckets];
+ Arrays.fill(buckets, -1);
+ }
+
+ public AggregateHashMap(StructType schema) {
+ this(schema, DEFAULT_CAPACITY, DEFAULT_LOAD_FACTOR, DEFAULT_MAX_STEPS);
+ }
+
+ public ColumnarBatch.Row findOrInsert(long key) {
+ int idx = find(key);
+ if (idx != -1 && buckets[idx] == -1) {
+ batch.column(0).putLong(numRows, key);
+ batch.column(1).putLong(numRows, 0);
+ buckets[idx] = numRows++;
+ }
+ return batch.getRow(buckets[idx]);
+ }
+
+ @VisibleForTesting
+ public int find(long key) {
+ long h = hash(key);
+ int step = 0;
+ int idx = (int) h & (numBuckets - 1);
+ while (step < maxSteps) {
+ // Return bucket index if it's either an empty slot or already contains the key
+ if (buckets[idx] == -1) {
+ return idx;
+ } else if (equals(idx, key)) {
+ return idx;
+ }
+ idx = (idx + 1) & (numBuckets - 1);
+ step++;
+ }
+ // Didn't find it
+ return -1;
+ }
+
+ private long hash(long key) {
+ return key;
+ }
+
+ private boolean equals(int idx, long key1) {
+ return batch.column(0).getLong(buckets[idx]) == key1;
+ }
+}
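A hedged usage sketch for the map above (Scala; the two-long schema and the key are illustrative, not from the patch):

  import org.apache.spark.sql.execution.vectorized.AggregateHashMap
  import org.apache.spark.sql.types.{LongType, StructType}

  // The constructor asserts a (long key, long value) schema.
  val schema = new StructType().add("k", LongType).add("v", LongType)
  val map = new AggregateHashMap(schema)

  // findOrInsert returns the columnar row for the key, inserting a zero-initialized
  // value on first sight; the caller then accumulates into column 1.
  val row = map.findOrInsert(42L)
  row.setLong(1, row.getLong(1) + 1L)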
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java
index 74fa6323cc..ff1f6680a7 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java
@@ -56,7 +56,7 @@ import org.apache.spark.unsafe.types.UTF8String;
*
* ColumnVectors are intended to be reused.
*/
-public abstract class ColumnVector {
+public abstract class ColumnVector implements AutoCloseable {
/**
* Allocates a column to store elements of `type` on or off heap.
* Capacity is the initial capacity of the vector and it will grow as necessary. Capacity is
@@ -566,6 +566,18 @@ public abstract class ColumnVector {
}
}
+
+ public final void putDecimal(int rowId, Decimal value, int precision) {
+ if (precision <= Decimal.MAX_INT_DIGITS()) {
+ putInt(rowId, value.toInt());
+ } else if (precision <= Decimal.MAX_LONG_DIGITS()) {
+ putLong(rowId, value.toLong());
+ } else {
+ BigInteger bigInteger = value.toJavaBigDecimal().unscaledValue();
+ putByteArray(rowId, bigInteger.toByteArray());
+ }
+ }
+
/**
* Returns the UTF8String for rowId.
*/
@@ -901,6 +913,11 @@ public abstract class ColumnVector {
}
/**
+ * Returns true if this column has a dictionary.
+ */
+ public boolean hasDictionary() { return this.dictionary != null; }
+
+ /**
* Reserve a integer column for ids of dictionary.
*/
public ColumnVector reserveDictionaryIds(int capacity) {
@@ -915,6 +932,13 @@ public abstract class ColumnVector {
}
/**
+ * Returns the underlying integer column for ids of dictionary.
+ */
+ public ColumnVector getDictionaryIds() {
+ return dictionaryIds;
+ }
+
+ /**
* Sets up the common state and also handles creating the child columns if this is a nested
* type.
*/
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java
index 792e17911f..8cece73faa 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java
@@ -16,6 +16,7 @@
*/
package org.apache.spark.sql.execution.vectorized;
+import java.math.BigDecimal;
import java.util.*;
import org.apache.commons.lang.NotImplementedException;
@@ -23,6 +24,7 @@ import org.apache.commons.lang.NotImplementedException;
import org.apache.spark.memory.MemoryMode;
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.catalyst.expressions.GenericMutableRow;
+import org.apache.spark.sql.catalyst.expressions.MutableRow;
import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
import org.apache.spark.sql.catalyst.util.ArrayData;
import org.apache.spark.sql.catalyst.util.MapData;
@@ -79,7 +81,7 @@ public final class ColumnarBatch {
/**
* Called to close all the columns in this batch. It is not valid to access the data after
- * calling this. This must be called at the end to clean up memory allcoations.
+ * calling this. This must be called at the end to clean up memory allocations.
*/
public void close() {
for (ColumnVector c: columns) {
@@ -91,7 +93,7 @@ public final class ColumnarBatch {
* Adapter class to interop with existing components that expect internal row. A lot of
* performance is lost with this translation.
*/
- public static final class Row extends InternalRow {
+ public static final class Row extends MutableRow {
protected int rowId;
private final ColumnarBatch parent;
private final int fixedLenRowSize;
@@ -232,6 +234,96 @@ public final class ColumnarBatch {
public Object get(int ordinal, DataType dataType) {
throw new NotImplementedException();
}
+
+ @Override
+ public void update(int ordinal, Object value) {
+ if (value == null) {
+ setNullAt(ordinal);
+ } else {
+ DataType dt = columns[ordinal].dataType();
+ if (dt instanceof BooleanType) {
+ setBoolean(ordinal, (boolean) value);
+ } else if (dt instanceof IntegerType) {
+ setInt(ordinal, (int) value);
+ } else if (dt instanceof ShortType) {
+ setShort(ordinal, (short) value);
+ } else if (dt instanceof LongType) {
+ setLong(ordinal, (long) value);
+ } else if (dt instanceof FloatType) {
+ setFloat(ordinal, (float) value);
+ } else if (dt instanceof DoubleType) {
+ setDouble(ordinal, (double) value);
+ } else if (dt instanceof DecimalType) {
+ DecimalType t = (DecimalType) dt;
+ setDecimal(ordinal, Decimal.apply((BigDecimal) value, t.precision(), t.scale()),
+ t.precision());
+ } else {
+ throw new NotImplementedException("Datatype not supported " + dt);
+ }
+ }
+ }
+
+ @Override
+ public void setNullAt(int ordinal) {
+ assert (!columns[ordinal].isConstant);
+ columns[ordinal].putNull(rowId);
+ }
+
+ @Override
+ public void setBoolean(int ordinal, boolean value) {
+ assert (!columns[ordinal].isConstant);
+ columns[ordinal].putNotNull(rowId);
+ columns[ordinal].putBoolean(rowId, value);
+ }
+
+ @Override
+ public void setByte(int ordinal, byte value) {
+ assert (!columns[ordinal].isConstant);
+ columns[ordinal].putNotNull(rowId);
+ columns[ordinal].putByte(rowId, value);
+ }
+
+ @Override
+ public void setShort(int ordinal, short value) {
+ assert (!columns[ordinal].isConstant);
+ columns[ordinal].putNotNull(rowId);
+ columns[ordinal].putShort(rowId, value);
+ }
+
+ @Override
+ public void setInt(int ordinal, int value) {
+ assert (!columns[ordinal].isConstant);
+ columns[ordinal].putNotNull(rowId);
+ columns[ordinal].putInt(rowId, value);
+ }
+
+ @Override
+ public void setLong(int ordinal, long value) {
+ assert (!columns[ordinal].isConstant);
+ columns[ordinal].putNotNull(rowId);
+ columns[ordinal].putLong(rowId, value);
+ }
+
+ @Override
+ public void setFloat(int ordinal, float value) {
+ assert (!columns[ordinal].isConstant);
+ columns[ordinal].putNotNull(rowId);
+ columns[ordinal].putFloat(rowId, value);
+ }
+
+ @Override
+ public void setDouble(int ordinal, double value) {
+ assert (!columns[ordinal].isConstant);
+ columns[ordinal].putNotNull(rowId);
+ columns[ordinal].putDouble(rowId, value);
+ }
+
+ @Override
+ public void setDecimal(int ordinal, Decimal value, int precision) {
+ assert (!columns[ordinal].isConstant);
+ columns[ordinal].putNotNull(rowId);
+ columns[ordinal].putDecimal(rowId, value, precision);
+ }
}
/**
@@ -315,7 +407,7 @@ public final class ColumnarBatch {
public int numRows() { return numRows; }
/**
- * Returns the number of valid rowss.
+ * Returns the number of valid rows.
*/
public int numValidRows() {
assert(numRowsFiltered <= numRows);
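A short sketch of the new write path through ColumnarBatch.Row (Scala; assumes ColumnarBatch.setNumRows, which is not shown in this hunk):

  import org.apache.spark.memory.MemoryMode
  import org.apache.spark.sql.execution.vectorized.ColumnarBatch
  import org.apache.spark.sql.types.{IntegerType, StructType}

  val batch = ColumnarBatch.allocate(new StructType().add("a", IntegerType), MemoryMode.ON_HEAP)
  batch.column(0).putInt(0, 1)
  batch.setNumRows(1)

  // Row now extends MutableRow, so values can be written back through the row adapter.
  val row = batch.getRow(0)
  row.setInt(0, 7)
  row.update(0, Int.box(8)) // update() dispatches on the column's data type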
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java
index b1429fe7cb..e97276800d 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java
@@ -212,7 +212,7 @@ public final class OnHeapColumnVector extends ColumnVector {
public void putIntsLittleEndian(int rowId, int count, byte[] src, int srcIndex) {
int srcOffset = srcIndex + Platform.BYTE_ARRAY_OFFSET;
for (int i = 0; i < count; ++i) {
- intData[i + rowId] = Platform.getInt(src, srcOffset);;
+ intData[i + rowId] = Platform.getInt(src, srcOffset);
srcIndex += 4;
srcOffset += 4;
}
@@ -387,35 +387,49 @@ public final class OnHeapColumnVector extends ColumnVector {
arrayLengths = newLengths;
arrayOffsets = newOffsets;
} else if (type instanceof BooleanType) {
- byte[] newData = new byte[newCapacity];
- if (byteData != null) System.arraycopy(byteData, 0, newData, 0, elementsAppended);
- byteData = newData;
+ if (byteData == null || byteData.length < newCapacity) {
+ byte[] newData = new byte[newCapacity];
+ if (byteData != null) System.arraycopy(byteData, 0, newData, 0, elementsAppended);
+ byteData = newData;
+ }
} else if (type instanceof ByteType) {
- byte[] newData = new byte[newCapacity];
- if (byteData != null) System.arraycopy(byteData, 0, newData, 0, elementsAppended);
- byteData = newData;
+ if (byteData == null || byteData.length < newCapacity) {
+ byte[] newData = new byte[newCapacity];
+ if (byteData != null) System.arraycopy(byteData, 0, newData, 0, elementsAppended);
+ byteData = newData;
+ }
} else if (type instanceof ShortType) {
- short[] newData = new short[newCapacity];
- if (shortData != null) System.arraycopy(shortData, 0, newData, 0, elementsAppended);
- shortData = newData;
+ if (shortData == null || shortData.length < newCapacity) {
+ short[] newData = new short[newCapacity];
+ if (shortData != null) System.arraycopy(shortData, 0, newData, 0, elementsAppended);
+ shortData = newData;
+ }
} else if (type instanceof IntegerType || type instanceof DateType ||
DecimalType.is32BitDecimalType(type)) {
- int[] newData = new int[newCapacity];
- if (intData != null) System.arraycopy(intData, 0, newData, 0, elementsAppended);
- intData = newData;
+ if (intData == null || intData.length < newCapacity) {
+ int[] newData = new int[newCapacity];
+ if (intData != null) System.arraycopy(intData, 0, newData, 0, elementsAppended);
+ intData = newData;
+ }
} else if (type instanceof LongType || type instanceof TimestampType ||
DecimalType.is64BitDecimalType(type)) {
- long[] newData = new long[newCapacity];
- if (longData != null) System.arraycopy(longData, 0, newData, 0, elementsAppended);
- longData = newData;
+ if (longData == null || longData.length < newCapacity) {
+ long[] newData = new long[newCapacity];
+ if (longData != null) System.arraycopy(longData, 0, newData, 0, elementsAppended);
+ longData = newData;
+ }
} else if (type instanceof FloatType) {
- float[] newData = new float[newCapacity];
- if (floatData != null) System.arraycopy(floatData, 0, newData, 0, elementsAppended);
- floatData = newData;
+ if (floatData == null || floatData.length < newCapacity) {
+ float[] newData = new float[newCapacity];
+ if (floatData != null) System.arraycopy(floatData, 0, newData, 0, elementsAppended);
+ floatData = newData;
+ }
} else if (type instanceof DoubleType) {
- double[] newData = new double[newCapacity];
- if (doubleData != null) System.arraycopy(doubleData, 0, newData, 0, elementsAppended);
- doubleData = newData;
+ if (doubleData == null || doubleData.length < newCapacity) {
+ double[] newData = new double[newCapacity];
+ if (doubleData != null) System.arraycopy(doubleData, 0, newData, 0, elementsAppended);
+ doubleData = newData;
+ }
} else if (resultStruct != null) {
// Nothing to store.
} else {
diff --git a/sql/core/src/main/java/org/apache/spark/sql/expressions/java/typed.java b/sql/core/src/main/java/org/apache/spark/sql/expressions/java/typed.java
new file mode 100644
index 0000000000..c7c6e3868f
--- /dev/null
+++ b/sql/core/src/main/java/org/apache/spark/sql/expressions/java/typed.java
@@ -0,0 +1,75 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.expressions.java;
+
+import org.apache.spark.annotation.Experimental;
+import org.apache.spark.api.java.function.MapFunction;
+import org.apache.spark.sql.TypedColumn;
+import org.apache.spark.sql.execution.aggregate.TypedAverage;
+import org.apache.spark.sql.execution.aggregate.TypedCount;
+import org.apache.spark.sql.execution.aggregate.TypedSumDouble;
+import org.apache.spark.sql.execution.aggregate.TypedSumLong;
+
+/**
+ * :: Experimental ::
+ * Type-safe functions available for {@link org.apache.spark.sql.Dataset} operations in Java.
+ *
+ * Scala users should use {@link org.apache.spark.sql.expressions.scala.typed}.
+ *
+ * @since 2.0.0
+ */
+@Experimental
+public class typed {
+ // Note: make sure to keep in sync with typed.scala
+
+ /**
+ * Average aggregate function.
+ *
+ * @since 2.0.0
+ */
+ public static <T> TypedColumn<T, Double> avg(MapFunction<T, Double> f) {
+ return new TypedAverage<T>(f).toColumnJava();
+ }
+
+ /**
+ * Count aggregate function.
+ *
+ * @since 2.0.0
+ */
+ public static <T> TypedColumn<T, Long> count(MapFunction<T, Object> f) {
+ return new TypedCount<T>(f).toColumnJava();
+ }
+
+ /**
+ * Sum aggregate function for floating point (double) type.
+ *
+ * @since 2.0.0
+ */
+ public static <T> TypedColumn<T, Double> sum(MapFunction<T, Double> f) {
+ return new TypedSumDouble<T>(f).toColumnJava();
+ }
+
+ /**
+ * Sum aggregate function for integral (long, i.e. 64 bit integer) type.
+ *
+ * @since 2.0.0
+ */
+ public static <T> TypedColumn<T, Long> sumLong(MapFunction<T, Long> f) {
+ return new TypedSumLong<T>(f).toColumnJava();
+ }
+}
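A hedged sketch of building one of these typed columns from Scala (the Purchase case class is illustrative; a Java 8 caller would pass a lambda where the anonymous MapFunction appears):

  import org.apache.spark.api.java.function.MapFunction
  import org.apache.spark.sql.TypedColumn
  import org.apache.spark.sql.expressions.java.typed

  case class Purchase(user: String, amount: Double)

  // Wraps TypedSumDouble; the resulting column is intended for typed aggregations
  // over a grouped Dataset[Purchase].
  val totalAmount: TypedColumn[Purchase, java.lang.Double] =
    typed.sum(new MapFunction[Purchase, java.lang.Double] {
      override def call(p: Purchase): java.lang.Double = p.amount
    })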
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQuery.scala b/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQuery.scala
index 1dc9a6893e..d9973b092d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQuery.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQuery.scala
@@ -94,7 +94,7 @@ trait ContinuousQuery {
/**
* Blocks until all available data in the source has been processed an committed to the sink.
* This method is intended for testing. Note that in the case of continually arriving data, this
- * method may block forever. Additionally, this method is only guranteed to block until data that
+ * method may block forever. Additionally, this method is only guaranteed to block until data that
* has been synchronously appended data to a [[org.apache.spark.sql.execution.streaming.Source]]
* prior to invocation. (i.e. `getOffset` must immediately reflect the addition).
*/
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQueryManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQueryManager.scala
index 465feeb604..1343e81569 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQueryManager.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQueryManager.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql
import scala.collection.mutable
import org.apache.spark.annotation.Experimental
-import org.apache.spark.sql.execution.streaming.{ContinuousQueryListenerBus, Sink, StreamExecution}
+import org.apache.spark.sql.execution.streaming._
import org.apache.spark.sql.execution.streaming.state.StateStoreCoordinatorRef
import org.apache.spark.sql.util.ContinuousQueryListener
@@ -171,13 +171,31 @@ class ContinuousQueryManager(sqlContext: SQLContext) {
name: String,
checkpointLocation: String,
df: DataFrame,
- sink: Sink): ContinuousQuery = {
+ sink: Sink,
+ trigger: Trigger = ProcessingTime(0)): ContinuousQuery = {
activeQueriesLock.synchronized {
if (activeQueries.contains(name)) {
throw new IllegalArgumentException(
s"Cannot start query with name $name as a query with that name is already active")
}
- val query = new StreamExecution(sqlContext, name, checkpointLocation, df.logicalPlan, sink)
+ var nextSourceId = 0L
+ val logicalPlan = df.logicalPlan.transform {
+ case StreamingRelation(dataSource, _, output) =>
+ // Materialize source to avoid creating it in every batch
+ val metadataPath = s"$checkpointLocation/sources/$nextSourceId"
+ val source = dataSource.createSource(metadataPath)
+ nextSourceId += 1
+ // We still need to use the previous `output` instead of `source.schema`, because
+ // attributes in "df.logicalPlan" already reference the previous `output`.
+ StreamingExecutionRelation(source, output)
+ }
+ val query = new StreamExecution(
+ sqlContext,
+ name,
+ checkpointLocation,
+ logicalPlan,
+ sink,
+ trigger)
query.start()
activeQueries.put(name, query)
query
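
Reading aid only: the rewrite above gives every streaming source a stable metadata directory under the checkpoint location, so a source is created once per query rather than once per batch. A minimal sketch of that numbering scheme, detached from the logical-plan machinery (names below are illustrative):

  // Sketch: mirrors s"$checkpointLocation/sources/$nextSourceId" from the hunk above,
  // assuming sources are counted in the order they appear in the plan.
  def sourceMetadataPaths(checkpointLocation: String, numSources: Int): Seq[String] =
    (0 until numSources).map(i => s"$checkpointLocation/sources/$i")

  // e.g. sourceMetadataPaths("/tmp/ckpt", 2) == Seq("/tmp/ckpt/sources/0", "/tmp/ckpt/sources/1")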
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala
index 33588ef72f..f0e16eefc7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala
@@ -200,6 +200,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
* The key of the map is the column name, and the value of the map is the replacement value.
* The value must be of the following type:
* `Integer`, `Long`, `Float`, `Double`, `String`, `Boolean`.
+ * Replacement values are cast to the column data type.
*
* For example, the following replaces null values in column "A" with string "unknown", and
* null values in column "B" with numeric value 1.0.
@@ -217,6 +218,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
*
* The key of the map is the column name, and the value of the map is the replacement value.
* The value must be of the following type: `Int`, `Long`, `Float`, `Double`, `String`, `Boolean`.
+ * Replacement values are cast to the column data type.
*
* For example, the following replaces null values in column "A" with string "unknown", and
* null values in column "B" with numeric value 1.0.
@@ -386,10 +388,10 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
val projections = df.schema.fields.map { f =>
values.find { case (k, _) => columnEquals(k, f.name) }.map { case (_, v) =>
v match {
- case v: jl.Float => fillCol[Double](f, v.toDouble)
+ case v: jl.Float => fillCol[Float](f, v)
case v: jl.Double => fillCol[Double](f, v)
- case v: jl.Long => fillCol[Double](f, v.toDouble)
- case v: jl.Integer => fillCol[Double](f, v.toDouble)
+ case v: jl.Long => fillCol[Long](f, v)
+ case v: jl.Integer => fillCol[Integer](f, v)
case v: jl.Boolean => fillCol[Boolean](f, v.booleanValue())
case v: String => fillCol[String](f, v)
}
@@ -402,13 +404,13 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
* Returns a [[Column]] expression that replaces null value in `col` with `replacement`.
*/
private def fillCol[T](col: StructField, replacement: T): Column = {
- col.dataType match {
+ val quotedColName = "`" + col.name + "`"
+ val colValue = col.dataType match {
case DoubleType | FloatType =>
- coalesce(nanvl(df.col("`" + col.name + "`"), lit(null)),
- lit(replacement).cast(col.dataType)).as(col.name)
- case _ =>
- coalesce(df.col("`" + col.name + "`"), lit(replacement).cast(col.dataType)).as(col.name)
+ nanvl(df.col(quotedColName), lit(null)) // nanvl only supports these types
+ case _ => df.col(quotedColName)
}
+ coalesce(colValue, lit(replacement)).cast(col.dataType).as(col.name)
}
/**
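
A small usage sketch of the fill path changed above, assuming a DataFrame `df` with a string column "A" and a double column "B"; as the updated scaladoc notes, each replacement value is cast to the target column's data type:

  // Sketch: nulls in "A" become "unknown"; nulls (and NaNs, via nanvl) in "B" become 1.0.
  val filled = df.na.fill(Map("A" -> "unknown", "B" -> 1.0))
  filled.show()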
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
index 704535adaa..15f2344df6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
@@ -176,7 +176,7 @@ class DataFrameReader private[sql](sqlContext: SQLContext) extends Logging {
userSpecifiedSchema = userSpecifiedSchema,
className = source,
options = extraOptions.toMap)
- Dataset.ofRows(sqlContext, StreamingRelation(dataSource.createSource()))
+ Dataset.ofRows(sqlContext, StreamingRelation(dataSource))
}
/**
@@ -315,8 +315,8 @@ class DataFrameReader private[sql](sqlContext: SQLContext) extends Logging {
*
* You can set the following JSON-specific options to deal with non-standard JSON files:
* <li>`primitivesAsString` (default `false`): infers all primitive values as a string type</li>
- * <li>`floatAsBigDecimal` (default `false`): infers all floating-point values as a decimal
- * type</li>
+ * <li>`prefersDecimal` (default `false`): infers all floating-point values as a decimal
+ * type. If the values do not fit in decimal, then it infers them as doubles.</li>
* <li>`allowComments` (default `false`): ignores Java/C++ style comment in JSON records</li>
* <li>`allowUnquotedFieldNames` (default `false`): allows unquoted JSON field names</li>
* <li>`allowSingleQuotes` (default `true`): allows single quotes in addition to double quotes
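
A hedged sketch of the renamed JSON option, assuming an existing `sqlContext`; the input path is illustrative only:

  // Sketch: `prefersDecimal` replaces `floatAsBigDecimal`; floating-point values that do
  // not fit in a decimal are inferred as doubles.
  val prices = sqlContext.read
    .option("prefersDecimal", "true")
    .json("/tmp/prices.json")
  prices.printSchema()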
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
index c07bd0e7b7..54d250867f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
@@ -29,8 +29,10 @@ import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoTable, Project}
import org.apache.spark.sql.execution.datasources.{BucketSpec, CreateTableUsingAsSelect, DataSource}
import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils
-import org.apache.spark.sql.execution.streaming.StreamExecution
+import org.apache.spark.sql.execution.streaming.{MemoryPlan, MemorySink, StreamExecution}
+import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.sources.HadoopFsRelation
+import org.apache.spark.util.Utils
/**
* :: Experimental ::
@@ -78,6 +80,35 @@ final class DataFrameWriter private[sql](df: DataFrame) {
}
/**
+ * :: Experimental ::
+ * Set the trigger for the stream query. The default value is `ProcessingTime(0)`, which runs
+ * the query as fast as possible.
+ *
+ * Scala Example:
+ * {{{
+ * df.write.trigger(ProcessingTime("10 seconds"))
+ *
+ * import scala.concurrent.duration._
+ * df.write.trigger(ProcessingTime(10.seconds))
+ * }}}
+ *
+ * Java Example:
+ * {{{
+ * df.write.trigger(ProcessingTime.create("10 seconds"))
+ *
+ * import java.util.concurrent.TimeUnit
+ * df.write.trigger(ProcessingTime.create(10, TimeUnit.SECONDS))
+ * }}}
+ *
+ * @since 2.0.0
+ */
+ @Experimental
+ def trigger(trigger: Trigger): DataFrameWriter = {
+ this.trigger = trigger
+ this
+ }
+
+ /**
* Specifies the underlying output data source. Built-in options include "parquet", "json", etc.
*
* @since 1.4.0
@@ -246,22 +277,64 @@ final class DataFrameWriter private[sql](df: DataFrame) {
* @since 2.0.0
*/
def startStream(): ContinuousQuery = {
- val dataSource =
- DataSource(
- df.sqlContext,
- className = source,
- options = extraOptions.toMap,
- partitionColumns = normalizedParCols.getOrElse(Nil))
-
- val queryName = extraOptions.getOrElse("queryName", StreamExecution.nextName)
- val checkpointLocation = extraOptions.getOrElse("checkpointLocation", {
- new Path(df.sqlContext.conf.checkpointLocation, queryName).toUri.toString
- })
- df.sqlContext.sessionState.continuousQueryManager.startQuery(
- queryName,
- checkpointLocation,
- df,
- dataSource.createSink())
+ if (source == "memory") {
+ val queryName =
+ extraOptions.getOrElse(
+ "queryName", throw new AnalysisException("queryName must be specified for memory sink"))
+ val checkpointLocation = extraOptions.get("checkpointLocation").map { userSpecified =>
+ new Path(userSpecified).toUri.toString
+ }.orElse {
+ val checkpointConfig: Option[String] =
+ df.sqlContext.conf.getConf(
+ SQLConf.CHECKPOINT_LOCATION,
+ None)
+
+ checkpointConfig.map { location =>
+ new Path(location, queryName).toUri.toString
+ }
+ }.getOrElse {
+ Utils.createTempDir(namePrefix = "memory.stream").getCanonicalPath
+ }
+
+ // If offsets have already been created, we are trying to resume a query.
+ val checkpointPath = new Path(checkpointLocation, "offsets")
+ val fs = checkpointPath.getFileSystem(df.sqlContext.sparkContext.hadoopConfiguration)
+ if (fs.exists(checkpointPath)) {
+ throw new AnalysisException(
+ s"Unable to resume query written to memory sink. Delete $checkpointPath to start over.")
+ } else {
+ checkpointPath.toUri.toString
+ }
+
+ val sink = new MemorySink(df.schema)
+ val resultDf = Dataset.ofRows(df.sqlContext, new MemoryPlan(sink))
+ resultDf.registerTempTable(queryName)
+ val continuousQuery = df.sqlContext.sessionState.continuousQueryManager.startQuery(
+ queryName,
+ checkpointLocation,
+ df,
+ sink,
+ trigger)
+ continuousQuery
+ } else {
+ val dataSource =
+ DataSource(
+ df.sqlContext,
+ className = source,
+ options = extraOptions.toMap,
+ partitionColumns = normalizedParCols.getOrElse(Nil))
+
+ val queryName = extraOptions.getOrElse("queryName", StreamExecution.nextName)
+ val checkpointLocation = extraOptions.getOrElse("checkpointLocation", {
+ new Path(df.sqlContext.conf.checkpointLocation, queryName).toUri.toString
+ })
+ df.sqlContext.sessionState.continuousQueryManager.startQuery(
+ queryName,
+ checkpointLocation,
+ df,
+ dataSource.createSink(),
+ trigger)
+ }
}
/**
@@ -552,6 +625,8 @@ final class DataFrameWriter private[sql](df: DataFrame) {
private var mode: SaveMode = SaveMode.ErrorIfExists
+ private var trigger: Trigger = ProcessingTime(0L)
+
private var extraOptions = new scala.collection.mutable.HashMap[String, String]
private var partitioningColumns: Option[Seq[String]] = None
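
Putting the new pieces together, a hedged sketch of starting a continuous query against the `memory` sink with an explicit trigger. It assumes `streamingDf` was built with `sqlContext.read.stream(...)`; the query name doubles as the temp table backed by the MemorySink:

  import org.apache.spark.sql.ProcessingTime

  // Sketch only: the memory sink requires a query name and uses a temporary checkpoint
  // directory unless one is configured; results land in the temp table "counts".
  val query = streamingDf.write
    .format("memory")
    .option("queryName", "counts")
    .trigger(ProcessingTime("10 seconds"))
    .startStream()

  sqlContext.table("counts").show()
  query.stop()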
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 703ea4d149..e216945fbe 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -22,8 +22,10 @@ import java.io.CharArrayWriter
import scala.collection.JavaConverters._
import scala.language.implicitConversions
import scala.reflect.runtime.universe.TypeTag
+import scala.util.control.NonFatal
import com.fasterxml.jackson.core.JsonFactory
+import org.apache.commons.lang3.StringUtils
import org.apache.spark.annotation.{DeveloperApi, Experimental}
import org.apache.spark.api.java.JavaRDD
@@ -39,11 +41,12 @@ import org.apache.spark.sql.catalyst.optimizer.CombineUnions
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.util.usePrettyExpression
-import org.apache.spark.sql.execution.{FileRelation, LogicalRDD, Queryable, QueryExecution, SQLExecution}
+import org.apache.spark.sql.execution.{FileRelation, LogicalRDD, QueryExecution, SQLExecution}
import org.apache.spark.sql.execution.command.ExplainCommand
import org.apache.spark.sql.execution.datasources.{CreateTableUsingAsSelect, LogicalRelation}
import org.apache.spark.sql.execution.datasources.json.JacksonGenerator
import org.apache.spark.sql.execution.python.EvaluatePython
+import org.apache.spark.sql.execution.streaming.{StreamingExecutionRelation, StreamingRelation}
import org.apache.spark.sql.types._
import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.Utils
@@ -150,10 +153,10 @@ private[sql] object Dataset {
* @since 1.6.0
*/
class Dataset[T] private[sql](
- @transient override val sqlContext: SQLContext,
- @DeveloperApi @transient override val queryExecution: QueryExecution,
+ @transient val sqlContext: SQLContext,
+ @DeveloperApi @transient val queryExecution: QueryExecution,
encoder: Encoder[T])
- extends Queryable with Serializable {
+ extends Serializable {
queryExecution.assertAnalyzed()
@@ -224,7 +227,7 @@ class Dataset[T] private[sql](
* @param _numRows Number of rows to show
* @param truncate Whether truncate long strings and align cells right
*/
- override private[sql] def showString(_numRows: Int, truncate: Boolean = true): String = {
+ private[sql] def showString(_numRows: Int, truncate: Boolean = true): String = {
val numRows = _numRows.max(0)
val takeResult = take(numRows + 1)
val hasMoreData = takeResult.length > numRows
@@ -249,7 +252,75 @@ class Dataset[T] private[sql](
}: Seq[String]
}
- formatString ( rows, numRows, hasMoreData, truncate )
+ val sb = new StringBuilder
+ val numCols = schema.fieldNames.length
+
+ // Initialise the width of each column to a minimum value of '3'
+ val colWidths = Array.fill(numCols)(3)
+
+ // Compute the width of each column
+ for (row <- rows) {
+ for ((cell, i) <- row.zipWithIndex) {
+ colWidths(i) = math.max(colWidths(i), cell.length)
+ }
+ }
+
+ // Create SeparateLine
+ val sep: String = colWidths.map("-" * _).addString(sb, "+", "+", "+\n").toString()
+
+ // column names
+ rows.head.zipWithIndex.map { case (cell, i) =>
+ if (truncate) {
+ StringUtils.leftPad(cell, colWidths(i))
+ } else {
+ StringUtils.rightPad(cell, colWidths(i))
+ }
+ }.addString(sb, "|", "|", "|\n")
+
+ sb.append(sep)
+
+ // data
+ rows.tail.map {
+ _.zipWithIndex.map { case (cell, i) =>
+ if (truncate) {
+ StringUtils.leftPad(cell.toString, colWidths(i))
+ } else {
+ StringUtils.rightPad(cell.toString, colWidths(i))
+ }
+ }.addString(sb, "|", "|", "|\n")
+ }
+
+ sb.append(sep)
+
+ // For Data that has more than "numRows" records
+ if (hasMoreData) {
+ val rowsString = if (numRows == 1) "row" else "rows"
+ sb.append(s"only showing top $numRows $rowsString\n")
+ }
+
+ sb.toString()
+ }
+
+ override def toString: String = {
+ try {
+ val builder = new StringBuilder
+ val fields = schema.take(2).map {
+ case f => s"${f.name}: ${f.dataType.simpleString(2)}"
+ }
+ builder.append("[")
+ builder.append(fields.mkString(", "))
+ if (schema.length > 2) {
+ if (schema.length - fields.size == 1) {
+ builder.append(" ... 1 more field")
+ } else {
+ builder.append(" ... " + (schema.length - 2) + " more fields")
+ }
+ }
+ builder.append("]").toString()
+ } catch {
+ case NonFatal(e) =>
+ s"Invalid tree; ${e.getMessage}:\n$queryExecution"
+ }
}
/**
@@ -325,7 +396,7 @@ class Dataset[T] private[sql](
* @since 1.6.0
*/
// scalastyle:off println
- override def printSchema(): Unit = println(schema.treeString)
+ def printSchema(): Unit = println(schema.treeString)
// scalastyle:on println
/**
@@ -334,7 +405,7 @@ class Dataset[T] private[sql](
* @group basic
* @since 1.6.0
*/
- override def explain(extended: Boolean): Unit = {
+ def explain(extended: Boolean): Unit = {
val explain = ExplainCommand(queryExecution.logical, extended = extended)
sqlContext.executePlan(explain).executedPlan.executeCollect().foreach {
// scalastyle:off println
@@ -349,7 +420,7 @@ class Dataset[T] private[sql](
* @group basic
* @since 1.6.0
*/
- override def explain(): Unit = explain(extended = false)
+ def explain(): Unit = explain(extended = false)
/**
* Returns all column names and their data types as an array.
@@ -379,6 +450,22 @@ class Dataset[T] private[sql](
def isLocal: Boolean = logicalPlan.isInstanceOf[LocalRelation]
/**
+ * Returns true if this [[Dataset]] contains one or more sources that continuously
+ * return data as it arrives. A [[Dataset]] that reads data from a streaming source
+ * must be executed as a [[ContinuousQuery]] using the `startStream()` method in
+ * [[DataFrameWriter]]. Methods that return a single answer (e.g., `count()` or
+ * `collect()`) will throw an [[AnalysisException]] when there is a streaming
+ * source present.
+ *
+ * @group basic
+ * @since 2.0.0
+ */
+ @Experimental
+ def isStreaming: Boolean = logicalPlan.find { n =>
+ n.isInstanceOf[StreamingRelation] || n.isInstanceOf[StreamingExecutionRelation]
+ }.isDefined
+
+ /**
* Displays the [[Dataset]] in a tabular form. Strings more than 20 characters will be truncated,
* and all cells will be aligned right. For example:
* {{{
@@ -678,7 +765,8 @@ class Dataset[T] private[sql](
implicit val tuple2Encoder: Encoder[(T, U)] =
ExpressionEncoder.tuple(this.unresolvedTEncoder, other.unresolvedTEncoder)
- withTypedPlan[(T, U)](other, encoderFor[(T, U)]) { (left, right) =>
+
+ withTypedPlan {
Project(
leftData :: rightData :: Nil,
joined.analyzed)
@@ -1404,6 +1492,8 @@ class Dataset[T] private[sql](
* @param weights weights for splits, will be normalized if they don't sum to 1.
* @param seed Seed for sampling.
*
+ * For Java API, use [[randomSplitAsList]].
+ *
* @group typedrel
* @since 2.0.0
*/
@@ -1422,6 +1512,20 @@ class Dataset[T] private[sql](
}
/**
+ * Returns a Java list that contains the randomly split [[Dataset]]s with the provided weights.
+ *
+ * @param weights weights for splits, will be normalized if they don't sum to 1.
+ * @param seed Seed for sampling.
+ *
+ * @group typedrel
+ * @since 2.0.0
+ */
+ def randomSplitAsList(weights: Array[Double], seed: Long): java.util.List[Dataset[T]] = {
+ val values = randomSplit(weights, seed)
+ java.util.Arrays.asList(values : _*)
+ }
+
+ /**
* Randomly splits this [[Dataset]] with the provided weights.
*
* @param weights weights for splits, will be normalized if they don't sum to 1.
@@ -1790,7 +1894,13 @@ class Dataset[T] private[sql](
* @since 1.6.0
*/
@Experimental
- def filter(func: T => Boolean): Dataset[T] = mapPartitions(_.filter(func))
+ def filter(func: T => Boolean): Dataset[T] = {
+ val deserialized = CatalystSerde.deserialize[T](logicalPlan)
+ val function = Literal.create(func, ObjectType(classOf[T => Boolean]))
+ val condition = Invoke(function, "apply", BooleanType, deserialized.output)
+ val filter = Filter(condition, deserialized)
+ withTypedPlan(CatalystSerde.serialize[T](filter))
+ }
/**
* :: Experimental ::
@@ -1801,7 +1911,13 @@ class Dataset[T] private[sql](
* @since 1.6.0
*/
@Experimental
- def filter(func: FilterFunction[T]): Dataset[T] = filter(t => func.call(t))
+ def filter(func: FilterFunction[T]): Dataset[T] = {
+ val deserialized = CatalystSerde.deserialize[T](logicalPlan)
+ val function = Literal.create(func, ObjectType(classOf[FilterFunction[T]]))
+ val condition = Invoke(function, "call", BooleanType, deserialized.output)
+ val filter = Filter(condition, deserialized)
+ withTypedPlan(CatalystSerde.serialize[T](filter))
+ }
/**
* :: Experimental ::
@@ -1812,7 +1928,9 @@ class Dataset[T] private[sql](
* @since 1.6.0
*/
@Experimental
- def map[U : Encoder](func: T => U): Dataset[U] = mapPartitions(_.map(func))
+ def map[U : Encoder](func: T => U): Dataset[U] = withTypedPlan {
+ MapElements[T, U](func, logicalPlan)
+ }
/**
* :: Experimental ::
@@ -1823,8 +1941,10 @@ class Dataset[T] private[sql](
* @since 1.6.0
*/
@Experimental
- def map[U](func: MapFunction[T, U], encoder: Encoder[U]): Dataset[U] =
- map(t => func.call(t))(encoder)
+ def map[U](func: MapFunction[T, U], encoder: Encoder[U]): Dataset[U] = {
+ implicit val uEnc = encoder
+ withTypedPlan(MapElements[T, U](func, logicalPlan))
+ }
/**
* :: Experimental ::
@@ -1987,6 +2107,24 @@ class Dataset[T] private[sql](
}
/**
+ * Returns an iterator that contains all of the [[Row]]s in this [[Dataset]].
+ *
+ * The iterator will consume as much memory as the largest partition in this [[Dataset]].
+ *
+ * Note: this results in multiple Spark jobs, and if the input Dataset is the result
+ * of a wide transformation (e.g. a join with different partitioners), the input
+ * Dataset should be cached first to avoid recomputing it.
+ *
+ * @group action
+ * @since 2.0.0
+ */
+ def toLocalIterator(): java.util.Iterator[T] = withCallback("toLocalIterator", toDF()) { _ =>
+ withNewExecutionId {
+ queryExecution.executedPlan.executeToIterator().map(boundTEncoder.fromRow).asJava
+ }
+ }
+
+ /**
* Returns the number of rows in the [[Dataset]].
* @group action
* @since 1.6.0
@@ -2007,7 +2145,7 @@ class Dataset[T] private[sql](
/**
* Returns a new [[Dataset]] partitioned by the given partitioning expressions into
- * `numPartitions`. The resulting Datasetis hash partitioned.
+ * `numPartitions`. The resulting Dataset is hash partitioned.
*
* This is the same operation as "DISTRIBUTE BY" in SQL (Hive QL).
*
@@ -2230,6 +2368,12 @@ class Dataset[T] private[sql](
}
}
+ protected[sql] def toPythonIterator(): Int = {
+ withNewExecutionId {
+ PythonRDD.toLocalIteratorAndServe(javaToPython.rdd)
+ }
+ }
+
////////////////////////////////////////////////////////////////////////////
// Private Helpers
////////////////////////////////////////////////////////////////////////////
@@ -2300,12 +2444,7 @@ class Dataset[T] private[sql](
}
/** A convenient function to wrap a logical plan and produce a Dataset. */
- @inline private def withTypedPlan(logicalPlan: => LogicalPlan): Dataset[T] = {
- new Dataset[T](sqlContext, logicalPlan, encoder)
+ @inline private def withTypedPlan[U : Encoder](logicalPlan: => LogicalPlan): Dataset[U] = {
+ Dataset(sqlContext, logicalPlan)
}
-
- private[sql] def withTypedPlan[R](
- other: Dataset[_], encoder: Encoder[R])(
- f: (LogicalPlan, LogicalPlan) => LogicalPlan): Dataset[R] =
- new Dataset[R](sqlContext, f(logicalPlan, other.logicalPlan), encoder)
}
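
A few hedged usage sketches for the Dataset additions above, assuming a non-streaming `ds: Dataset[String]` built from local data:

  import scala.collection.JavaConverters._

  // isStreaming is false unless the plan contains a StreamingRelation.
  assert(!ds.isStreaming)

  // Java-friendly variant of randomSplit; returns java.util.List[Dataset[String]].
  val parts = ds.randomSplitAsList(Array(0.7, 0.3), 42L)

  // toLocalIterator pulls one partition at a time to the driver; as the scaladoc notes,
  // cache `ds` first if it is expensive to recompute.
  ds.toLocalIterator().asScala.foreach(println)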
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/ExperimentalMethods.scala b/sql/core/src/main/scala/org/apache/spark/sql/ExperimentalMethods.scala
index d7cd84fd24..c5df028485 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/ExperimentalMethods.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/ExperimentalMethods.scala
@@ -37,7 +37,7 @@ class ExperimentalMethods private[sql]() {
/**
* Allows extra strategies to be injected into the query planner at runtime. Note this API
- * should be consider experimental and is not intended to be stable across releases.
+ * should be considered experimental and is not intended to be stable across releases.
*
* @since 1.3.0
*/
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
index 07aa1515f3..f19ad6e707 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
@@ -57,13 +57,6 @@ class KeyValueGroupedDataset[K, V] private[sql](
private def logicalPlan = queryExecution.analyzed
private def sqlContext = queryExecution.sqlContext
- private def groupedData = {
- new RelationalGroupedDataset(
- Dataset.ofRows(sqlContext, logicalPlan),
- groupingAttributes,
- RelationalGroupedDataset.GroupByType)
- }
-
/**
* Returns a new [[KeyValueGroupedDataset]] where the type of the key has been mapped to the
* specified type. The mapping of key columns to the type follows the same rules as `as` on
@@ -207,12 +200,6 @@ class KeyValueGroupedDataset[K, V] private[sql](
reduceGroups(f.call _)
}
- private def withEncoder(c: Column): Column = c match {
- case tc: TypedColumn[_, _] =>
- tc.withInputType(resolvedVEncoder.bind(dataAttributes), dataAttributes)
- case _ => c
- }
-
/**
* Internal helper function for building typed aggregations that return tuples. For simplicity
* and code reuse, we do this without the help of the type system and then use helper functions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
index 91c02053ae..7dbf2e6c7c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
@@ -408,7 +408,7 @@ private[sql] object RelationalGroupedDataset {
private[sql] object RollupType extends GroupType
/**
- * To indicate it's the PIVOT
- */
+ * To indicate it's the PIVOT
+ */
private[sql] case class PivotType(pivotCol: Expression, values: Seq[Literal]) extends GroupType
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index e413e77bc1..9259ff4062 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -29,10 +29,11 @@ import org.apache.spark.{SparkConf, SparkContext, SparkException}
import org.apache.spark.annotation.{DeveloperApi, Experimental}
import org.apache.spark.api.java.{JavaRDD, JavaSparkContext}
import org.apache.spark.internal.Logging
+import org.apache.spark.internal.config.ConfigEntry
import org.apache.spark.rdd.RDD
import org.apache.spark.scheduler.{SparkListener, SparkListenerApplicationEnd}
import org.apache.spark.sql.catalyst._
-import org.apache.spark.sql.catalyst.catalog.{ExternalCatalog, InMemoryCatalog}
+import org.apache.spark.sql.catalyst.catalog._
import org.apache.spark.sql.catalyst.encoders.encoderFor
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Range}
@@ -41,7 +42,6 @@ import org.apache.spark.sql.execution.command.ShowTablesCommand
import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.execution.ui.{SQLListener, SQLTab}
import org.apache.spark.sql.internal.{SessionState, SQLConf}
-import org.apache.spark.sql.internal.SQLConf.SQLConfEntry
import org.apache.spark.sql.sources.BaseRelation
import org.apache.spark.sql.types._
import org.apache.spark.sql.util.ExecutionListenerManager
@@ -120,7 +120,7 @@ class SQLContext private[sql](
*/
@transient
protected[sql] lazy val sessionState: SessionState = new SessionState(self)
- protected[sql] def conf: SQLConf = sessionState.conf
+ protected[spark] def conf: SQLConf = sessionState.conf
/**
* An interface to register custom [[org.apache.spark.sql.util.QueryExecutionListener]]s
@@ -138,7 +138,7 @@ class SQLContext private[sql](
def setConf(props: Properties): Unit = conf.setConf(props)
/** Set the given Spark SQL configuration property. */
- private[sql] def setConf[T](entry: SQLConfEntry[T], value: T): Unit = conf.setConf(entry, value)
+ private[sql] def setConf[T](entry: ConfigEntry[T], value: T): Unit = conf.setConf(entry, value)
/**
* Set the given Spark SQL configuration property.
@@ -158,16 +158,16 @@ class SQLContext private[sql](
/**
* Return the value of Spark SQL configuration property for the given key. If the key is not set
- * yet, return `defaultValue` in [[SQLConfEntry]].
+ * yet, return `defaultValue` in [[ConfigEntry]].
*/
- private[sql] def getConf[T](entry: SQLConfEntry[T]): T = conf.getConf(entry)
+ private[sql] def getConf[T](entry: ConfigEntry[T]): T = conf.getConf(entry)
/**
* Return the value of Spark SQL configuration property for the given key. If the key is not set
- * yet, return `defaultValue`. This is useful when `defaultValue` in SQLConfEntry is not the
+ * yet, return `defaultValue`. This is useful when `defaultValue` in ConfigEntry is not the
* desired one.
*/
- private[sql] def getConf[T](entry: SQLConfEntry[T], defaultValue: T): T = {
+ private[sql] def getConf[T](entry: ConfigEntry[T], defaultValue: T): T = {
conf.getConf(entry, defaultValue)
}
@@ -208,6 +208,22 @@ class SQLContext private[sql](
sparkContext.addJar(path)
}
+ /** A [[FunctionResourceLoader]] that can be used in SessionCatalog. */
+ @transient protected[sql] lazy val functionResourceLoader: FunctionResourceLoader = {
+ new FunctionResourceLoader {
+ override def loadResource(resource: FunctionResource): Unit = {
+ resource.resourceType match {
+ case JarResource => addJar(resource.uri)
+ case FileResource => sparkContext.addFile(resource.uri)
+ case ArchiveResource =>
+ throw new AnalysisException(
+ "Archive is not allowed to be loaded. If YARN mode is used, " +
+ "please use --archives options while calling spark-submit.")
+ }
+ }
+ }
+ }
+
/**
* :: Experimental ::
* A collection of methods that are considered experimental, but can be used to hook into
@@ -272,11 +288,11 @@ class SQLContext private[sql](
}
/**
- * Returns true if the [[Queryable]] is currently cached in-memory.
+ * Returns true if the [[Dataset]] is currently cached in-memory.
* @group cachemgmt
* @since 1.3.0
*/
- private[sql] def isCached(qName: Queryable): Boolean = {
+ private[sql] def isCached(qName: Dataset[_]): Boolean = {
cacheManager.lookupCachedData(qName).nonEmpty
}
@@ -671,7 +687,7 @@ class SQLContext private[sql](
sessionState.catalog.createTempTable(
sessionState.sqlParser.parseTableIdentifier(tableName).table,
df.logicalPlan,
- ignoreIfExists = true)
+ overrideIfExists = true)
}
/**
@@ -712,13 +728,13 @@ class SQLContext private[sql](
}
/**
- * :: Experimental ::
- * Creates a [[Dataset]] with a single [[LongType]] column named `id`, containing elements
- * in an range from `start` to `end` (exclusive) with an step value.
- *
- * @since 2.0.0
- * @group dataset
- */
+ * :: Experimental ::
+ * Creates a [[Dataset]] with a single [[LongType]] column named `id`, containing elements
+ * in a range from `start` to `end` (exclusive) with a step value.
+ *
+ * @since 2.0.0
+ * @group dataset
+ */
@Experimental
def range(start: Long, end: Long, step: Long): Dataset[java.lang.Long] = {
range(start, end, step, numPartitions = sparkContext.defaultParallelism)
@@ -781,7 +797,7 @@ class SQLContext private[sql](
* @since 1.3.0
*/
def tables(): DataFrame = {
- Dataset.ofRows(this, ShowTablesCommand(None))
+ Dataset.ofRows(this, ShowTablesCommand(None, None))
}
/**
@@ -793,7 +809,7 @@ class SQLContext private[sql](
* @since 1.3.0
*/
def tables(databaseName: String): DataFrame = {
- Dataset.ofRows(this, ShowTablesCommand(Some(databaseName)))
+ Dataset.ofRows(this, ShowTablesCommand(Some(databaseName), None))
}
/**
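
A short hedged sketch of two SQLContext entry points touched above:

  // range(start, end, step): a Dataset[java.lang.Long] containing 0, 2, 4, 6, 8.
  sqlContext.range(0L, 10L, 2L).show()

  // tables(db) is now planned as ShowTablesCommand(Some(db), None).
  sqlContext.tables("default").show()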
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala
index c35a969bf0..ad69e23540 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala
@@ -44,33 +44,33 @@ abstract class SQLImplicits {
}
/** @since 1.6.0 */
- implicit def newProductEncoder[T <: Product : TypeTag]: Encoder[T] = ExpressionEncoder()
+ implicit def newProductEncoder[T <: Product : TypeTag]: Encoder[T] = Encoders.product[T]
// Primitives
/** @since 1.6.0 */
- implicit def newIntEncoder: Encoder[Int] = ExpressionEncoder()
+ implicit def newIntEncoder: Encoder[Int] = Encoders.scalaInt
/** @since 1.6.0 */
- implicit def newLongEncoder: Encoder[Long] = ExpressionEncoder()
+ implicit def newLongEncoder: Encoder[Long] = Encoders.scalaLong
/** @since 1.6.0 */
- implicit def newDoubleEncoder: Encoder[Double] = ExpressionEncoder()
+ implicit def newDoubleEncoder: Encoder[Double] = Encoders.scalaDouble
/** @since 1.6.0 */
- implicit def newFloatEncoder: Encoder[Float] = ExpressionEncoder()
+ implicit def newFloatEncoder: Encoder[Float] = Encoders.scalaFloat
/** @since 1.6.0 */
- implicit def newByteEncoder: Encoder[Byte] = ExpressionEncoder()
+ implicit def newByteEncoder: Encoder[Byte] = Encoders.scalaByte
/** @since 1.6.0 */
- implicit def newShortEncoder: Encoder[Short] = ExpressionEncoder()
+ implicit def newShortEncoder: Encoder[Short] = Encoders.scalaShort
/** @since 1.6.0 */
- implicit def newBooleanEncoder: Encoder[Boolean] = ExpressionEncoder()
+ implicit def newBooleanEncoder: Encoder[Boolean] = Encoders.scalaBoolean
/** @since 1.6.0 */
- implicit def newStringEncoder: Encoder[String] = ExpressionEncoder()
+ implicit def newStringEncoder: Encoder[String] = Encoders.STRING
// Seqs
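
With the implicits above now delegating to `Encoders`, the usual pattern is unchanged; a hedged sketch (the case class is illustrative):

  import sqlContext.implicits._

  // Primitive encoders now come from Encoders.scalaInt, Encoders.STRING, etc.
  val ints = Seq(1, 2, 3).toDS()

  // Product types go through Encoders.product[T].
  case class Person(name: String, age: Int)
  val people = Seq(Person("alice", 29), Person("bob", 31)).toDS()
  people.show()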
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Trigger.scala b/sql/core/src/main/scala/org/apache/spark/sql/Trigger.scala
new file mode 100644
index 0000000000..c4e54b3f90
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Trigger.scala
@@ -0,0 +1,133 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import java.util.concurrent.TimeUnit
+
+import scala.concurrent.duration.Duration
+
+import org.apache.commons.lang3.StringUtils
+
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.unsafe.types.CalendarInterval
+
+/**
+ * :: Experimental ::
+ * Used to indicate how often results should be produced by a [[ContinuousQuery]].
+ */
+@Experimental
+sealed trait Trigger {}
+
+/**
+ * :: Experimental ::
+ * A trigger that runs a query periodically based on the processing time. If `intervalMs` is 0,
+ * the query will run as fast as possible.
+ *
+ * Scala Example:
+ * {{{
+ * df.write.trigger(ProcessingTime("10 seconds"))
+ *
+ * import scala.concurrent.duration._
+ * df.write.trigger(ProcessingTime(10.seconds))
+ * }}}
+ *
+ * Java Example:
+ * {{{
+ * df.write.trigger(ProcessingTime.create("10 seconds"))
+ *
+ * import java.util.concurrent.TimeUnit
+ * df.write.trigger(ProcessingTime.create(10, TimeUnit.SECONDS))
+ * }}}
+ */
+@Experimental
+case class ProcessingTime(intervalMs: Long) extends Trigger {
+ require(intervalMs >= 0, "the interval of trigger should not be negative")
+}
+
+/**
+ * :: Experimental ::
+ * Used to create [[ProcessingTime]] triggers for [[ContinuousQuery]]s.
+ */
+@Experimental
+object ProcessingTime {
+
+ /**
+ * Create a [[ProcessingTime]]. If `intervalMs` is 0, the query will run as fast as possible.
+ *
+ * Example:
+ * {{{
+ * df.write.trigger(ProcessingTime("10 seconds"))
+ * }}}
+ */
+ def apply(interval: String): ProcessingTime = {
+ if (StringUtils.isBlank(interval)) {
+ throw new IllegalArgumentException(
+ "interval cannot be null or blank.")
+ }
+ val cal = if (interval.startsWith("interval")) {
+ CalendarInterval.fromString(interval)
+ } else {
+ CalendarInterval.fromString("interval " + interval)
+ }
+ if (cal == null) {
+ throw new IllegalArgumentException(s"Invalid interval: $interval")
+ }
+ if (cal.months > 0) {
+ throw new IllegalArgumentException(s"Doesn't support month or year interval: $interval")
+ }
+ new ProcessingTime(cal.microseconds / 1000)
+ }
+
+ /**
+ * Create a [[ProcessingTime]]. If `intervalMs` is 0, the query will run as fast as possible.
+ *
+ * Example:
+ * {{{
+ * import scala.concurrent.duration._
+ * df.write.trigger(ProcessingTime(10.seconds))
+ * }}}
+ */
+ def apply(interval: Duration): ProcessingTime = {
+ new ProcessingTime(interval.toMillis)
+ }
+
+ /**
+ * Create a [[ProcessingTime]]. If `intervalMs` is 0, the query will run as fast as possible.
+ *
+ * Example:
+ * {{{
+ * df.write.trigger(ProcessingTime.create("10 seconds"))
+ * }}}
+ */
+ def create(interval: String): ProcessingTime = {
+ apply(interval)
+ }
+
+ /**
+ * Create a [[ProcessingTime]]. If `intervalMs` is 0, the query will run as fast as possible.
+ *
+ * Example:
+ * {{{
+ * import java.util.concurrent.TimeUnit
+ * df.write.trigger(ProcessingTime.create(10, TimeUnit.SECONDS))
+ * }}}
+ */
+ def create(interval: Long, unit: TimeUnit): ProcessingTime = {
+ new ProcessingTime(unit.toMillis(interval))
+ }
+}
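
A hedged sketch exercising the four ways of building a ProcessingTime defined above; each yields an interval of 10000 ms:

  import java.util.concurrent.TimeUnit
  import scala.concurrent.duration._

  import org.apache.spark.sql.ProcessingTime

  val fromString   = ProcessingTime("10 seconds")              // parsed via CalendarInterval
  val fromDuration = ProcessingTime(10.seconds)                // scala.concurrent.duration
  val fromJavaStr  = ProcessingTime.create("10 seconds")       // Java-friendly string factory
  val fromUnit     = ProcessingTime.create(10, TimeUnit.SECONDS)
  assert(Set(fromString, fromDuration, fromJavaStr, fromUnit) == Set(new ProcessingTime(10000L)))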
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala
index 14b8b6fc3b..124ec09efd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala
@@ -22,6 +22,7 @@ import java.util.concurrent.locks.ReentrantReadWriteLock
import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.columnar.InMemoryRelation
+import org.apache.spark.sql.Dataset
import org.apache.spark.storage.StorageLevel
import org.apache.spark.storage.StorageLevel.MEMORY_AND_DISK
@@ -74,12 +75,12 @@ private[sql] class CacheManager extends Logging {
}
/**
- * Caches the data produced by the logical representation of the given [[Queryable]].
+ * Caches the data produced by the logical representation of the given [[Dataset]].
* Unlike `RDD.cache()`, the default storage level is set to be `MEMORY_AND_DISK` because
* recomputing the in-memory columnar representation of the underlying table is expensive.
*/
private[sql] def cacheQuery(
- query: Queryable,
+ query: Dataset[_],
tableName: Option[String] = None,
storageLevel: StorageLevel = MEMORY_AND_DISK): Unit = writeLock {
val planToCache = query.queryExecution.analyzed
@@ -99,8 +100,8 @@ private[sql] class CacheManager extends Logging {
}
}
- /** Removes the data for the given [[Queryable]] from the cache */
- private[sql] def uncacheQuery(query: Queryable, blocking: Boolean = true): Unit = writeLock {
+ /** Removes the data for the given [[Dataset]] from the cache */
+ private[sql] def uncacheQuery(query: Dataset[_], blocking: Boolean = true): Unit = writeLock {
val planToCache = query.queryExecution.analyzed
val dataIndex = cachedData.indexWhere(cd => planToCache.sameResult(cd.plan))
require(dataIndex >= 0, s"Table $query is not cached.")
@@ -108,11 +109,12 @@ private[sql] class CacheManager extends Logging {
cachedData.remove(dataIndex)
}
- /** Tries to remove the data for the given [[Queryable]] from the cache
- * if it's cached
- */
+ /**
+ * Tries to remove the data for the given [[Dataset]] from the cache
+ * if it's cached
+ */
private[sql] def tryUncacheQuery(
- query: Queryable,
+ query: Dataset[_],
blocking: Boolean = true): Boolean = writeLock {
val planToCache = query.queryExecution.analyzed
val dataIndex = cachedData.indexWhere(cd => planToCache.sameResult(cd.plan))
@@ -124,8 +126,8 @@ private[sql] class CacheManager extends Logging {
found
}
- /** Optionally returns cached data for the given [[Queryable]] */
- private[sql] def lookupCachedData(query: Queryable): Option[CachedData] = readLock {
+ /** Optionally returns cached data for the given [[Dataset]] */
+ private[sql] def lookupCachedData(query: Dataset[_]): Option[CachedData] = readLock {
lookupCachedData(query.queryExecution.analyzed)
}
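
With Queryable removed, the cache manager is keyed on Dataset query plans directly. A hedged sketch of the user-facing path, assuming cache()/unpersist() on Dataset still route through cacheQuery and tryUncacheQuery:

  val ds = sqlContext.range(0L, 1000L)
  ds.cache()       // cacheQuery(ds), default storage level MEMORY_AND_DISK
  ds.count()       // first action materializes the in-memory columnar relation
  ds.unpersist()   // tryUncacheQuery(ds, blocking = true)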
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
index 815ff01c4c..392c48fb7b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
@@ -30,7 +30,7 @@ import org.apache.spark.sql.execution.datasources.parquet.{DefaultSource => Parq
import org.apache.spark.sql.execution.metric.SQLMetrics
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.sources.{BaseRelation, HadoopFsRelation}
-import org.apache.spark.sql.types.DataType
+import org.apache.spark.sql.types.{AtomicType, DataType}
object RDDConversions {
def productToRowRdd[A <: Product](data: RDD[A], outputTypes: Seq[DataType]): RDD[InternalRow] = {
@@ -123,24 +123,30 @@ private[sql] case class PhysicalRDD(
}
}
-/** Physical plan node for scanning data from a relation. */
-private[sql] case class DataSourceScan(
- output: Seq[Attribute],
- rdd: RDD[InternalRow],
- @transient relation: BaseRelation,
- override val metadata: Map[String, String] = Map.empty)
- extends LeafNode with CodegenSupport {
+private[sql] trait DataSourceScan extends LeafNode {
+ val rdd: RDD[InternalRow]
+ val relation: BaseRelation
override val nodeName: String = relation.toString
// Ignore rdd when checking results
- override def sameResult(plan: SparkPlan ): Boolean = plan match {
+ override def sameResult(plan: SparkPlan): Boolean = plan match {
case other: DataSourceScan => relation == other.relation && metadata == other.metadata
case _ => false
}
+}
- private[sql] override lazy val metrics = Map(
- "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))
+/** Physical plan node for scanning data from a relation. */
+private[sql] case class RowDataSourceScan(
+ output: Seq[Attribute],
+ rdd: RDD[InternalRow],
+ @transient relation: BaseRelation,
+ override val outputPartitioning: Partitioning,
+ override val metadata: Map[String, String] = Map.empty)
+ extends DataSourceScan with CodegenSupport {
+
+ private[sql] override lazy val metrics =
+ Map("numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))
val outputUnsafeRows = relation match {
case r: HadoopFsRelation if r.fileFormat.isInstanceOf[ParquetSource] =>
@@ -149,27 +155,6 @@ private[sql] case class DataSourceScan(
case _ => false
}
- override val outputPartitioning = {
- val bucketSpec = relation match {
- // TODO: this should be closer to bucket planning.
- case r: HadoopFsRelation if r.sqlContext.conf.bucketingEnabled => r.bucketSpec
- case _ => None
- }
-
- def toAttribute(colName: String): Attribute = output.find(_.name == colName).getOrElse {
- throw new AnalysisException(s"bucket column $colName not found in existing columns " +
- s"(${output.map(_.name).mkString(", ")})")
- }
-
- bucketSpec.map { spec =>
- val numBuckets = spec.numBuckets
- val bucketColumns = spec.bucketColumnNames.map(toAttribute)
- HashPartitioning(bucketColumns, numBuckets)
- }.getOrElse {
- UnknownPartitioning(0)
- }
- }
-
protected override def doExecute(): RDD[InternalRow] = {
val unsafeRow = if (outputUnsafeRows) {
rdd
@@ -196,6 +181,57 @@ private[sql] case class DataSourceScan(
rdd :: Nil
}
+ override protected def doProduce(ctx: CodegenContext): String = {
+ val numOutputRows = metricTerm(ctx, "numOutputRows")
+ // PhysicalRDD always just has one input
+ val input = ctx.freshName("input")
+ ctx.addMutableState("scala.collection.Iterator", input, s"$input = inputs[0];")
+ val exprRows = output.zipWithIndex.map{ case (a, i) =>
+ new BoundReference(i, a.dataType, a.nullable)
+ }
+ val row = ctx.freshName("row")
+ ctx.INPUT_ROW = row
+ ctx.currentVars = null
+ val columnsRowInput = exprRows.map(_.gen(ctx))
+ val inputRow = if (outputUnsafeRows) row else null
+ s"""
+ |while ($input.hasNext()) {
+ | InternalRow $row = (InternalRow) $input.next();
+ | $numOutputRows.add(1);
+ | ${consume(ctx, columnsRowInput, inputRow).trim}
+ | if (shouldStop()) return;
+ |}
+ """.stripMargin
+ }
+}
+
+/** Physical plan node for scanning data from a batched relation. */
+private[sql] case class BatchedDataSourceScan(
+ output: Seq[Attribute],
+ rdd: RDD[InternalRow],
+ @transient relation: BaseRelation,
+ override val outputPartitioning: Partitioning,
+ override val metadata: Map[String, String] = Map.empty)
+ extends DataSourceScan with CodegenSupport {
+
+ private[sql] override lazy val metrics =
+ Map("numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"),
+ "scanTime" -> SQLMetrics.createTimingMetric(sparkContext, "scan time"))
+
+ protected override def doExecute(): RDD[InternalRow] = {
+ throw new UnsupportedOperationException
+ }
+
+ override def simpleString: String = {
+ val metadataEntries = for ((key, value) <- metadata.toSeq.sorted) yield s"$key: $value"
+ val metadataStr = metadataEntries.mkString(" ", ", ", "")
+ s"BatchedScan $nodeName${output.mkString("[", ",", "]")}$metadataStr"
+ }
+
+ override def upstreams(): Seq[RDD[InternalRow]] = {
+ rdd :: Nil
+ }
+
private def genCodeColumnVector(ctx: CodegenContext, columnVar: String, ordinal: String,
dataType: DataType, nullable: Boolean): ExprCode = {
val javaType = ctx.javaType(dataType)
@@ -217,96 +253,64 @@ private[sql] case class DataSourceScan(
// Support codegen so that we can avoid the UnsafeRow conversion in all cases. Codegen
// never requires UnsafeRow as input.
override protected def doProduce(ctx: CodegenContext): String = {
- val columnarBatchClz = "org.apache.spark.sql.execution.vectorized.ColumnarBatch"
- val columnVectorClz = "org.apache.spark.sql.execution.vectorized.ColumnVector"
val input = ctx.freshName("input")
- val idx = ctx.freshName("batchIdx")
- val rowidx = ctx.freshName("rowIdx")
- val batch = ctx.freshName("batch")
// PhysicalRDD always just has one input
ctx.addMutableState("scala.collection.Iterator", input, s"$input = inputs[0];")
+
+ // metrics
+ val numOutputRows = metricTerm(ctx, "numOutputRows")
+ val scanTimeMetric = metricTerm(ctx, "scanTime")
+ val scanTimeTotalNs = ctx.freshName("scanTime")
+ ctx.addMutableState("long", scanTimeTotalNs, s"$scanTimeTotalNs = 0;")
+
+ val columnarBatchClz = "org.apache.spark.sql.execution.vectorized.ColumnarBatch"
+ val batch = ctx.freshName("batch")
ctx.addMutableState(columnarBatchClz, batch, s"$batch = null;")
+
+ val columnVectorClz = "org.apache.spark.sql.execution.vectorized.ColumnVector"
+ val idx = ctx.freshName("batchIdx")
ctx.addMutableState("int", idx, s"$idx = 0;")
val colVars = output.indices.map(i => ctx.freshName("colInstance" + i))
val columnAssigns = colVars.zipWithIndex.map { case (name, i) =>
ctx.addMutableState(columnVectorClz, name, s"$name = null;")
- s"$name = ${batch}.column($i);" }
-
- val row = ctx.freshName("row")
- val numOutputRows = metricTerm(ctx, "numOutputRows")
-
- // The input RDD can either return (all) ColumnarBatches or InternalRows. We determine this
- // by looking at the first value of the RDD and then calling the function which will process
- // the remaining. It is faster to return batches.
- // TODO: The abstractions between this class and SqlNewHadoopRDD makes it difficult to know
- // here which path to use. Fix this.
+ s"$name = $batch.column($i);"
+ }
- ctx.currentVars = null
- val columns1 = (output zip colVars).map { case (attr, colVar) =>
- genCodeColumnVector(ctx, colVar, rowidx, attr.dataType, attr.nullable) }
- val scanBatches = ctx.freshName("processBatches")
- ctx.addNewFunction(scanBatches,
+ val nextBatch = ctx.freshName("nextBatch")
+ ctx.addNewFunction(nextBatch,
s"""
- | private void $scanBatches() throws java.io.IOException {
- | while (true) {
- | int numRows = $batch.numRows();
- | if ($idx == 0) {
- | ${columnAssigns.mkString("", "\n", "\n")}
- | $numOutputRows.add(numRows);
- | }
- |
- | // this loop is very perf sensitive and changes to it should be measured carefully
- | while ($idx < numRows) {
- | int $rowidx = $idx++;
- | ${consume(ctx, columns1).trim}
- | if (shouldStop()) return;
- | }
- |
- | if (!$input.hasNext()) {
- | $batch = null;
- | break;
- | }
- | $batch = ($columnarBatchClz)$input.next();
- | $idx = 0;
- | }
- | }""".stripMargin)
-
- val exprRows =
- output.zipWithIndex.map(x => new BoundReference(x._2, x._1.dataType, x._1.nullable))
- ctx.INPUT_ROW = row
+ |private void $nextBatch() throws java.io.IOException {
+ | long getBatchStart = System.nanoTime();
+ | if ($input.hasNext()) {
+ | $batch = ($columnarBatchClz)$input.next();
+ | $numOutputRows.add($batch.numRows());
+ | $idx = 0;
+ | ${columnAssigns.mkString("", "\n", "\n")}
+ | }
+ | $scanTimeTotalNs += System.nanoTime() - getBatchStart;
+ |}""".stripMargin)
+
ctx.currentVars = null
- val columns2 = exprRows.map(_.gen(ctx))
- val inputRow = if (outputUnsafeRows) row else null
- val scanRows = ctx.freshName("processRows")
- ctx.addNewFunction(scanRows,
- s"""
- | private void $scanRows(InternalRow $row) throws java.io.IOException {
- | boolean firstRow = true;
- | while (firstRow || $input.hasNext()) {
- | if (firstRow) {
- | firstRow = false;
- | } else {
- | $row = (InternalRow) $input.next();
- | }
- | $numOutputRows.add(1);
- | ${consume(ctx, columns2, inputRow).trim}
- | if (shouldStop()) return;
- | }
- | }""".stripMargin)
-
- val value = ctx.freshName("value")
+ val rowidx = ctx.freshName("rowIdx")
+ val columnsBatchInput = (output zip colVars).map { case (attr, colVar) =>
+ genCodeColumnVector(ctx, colVar, rowidx, attr.dataType, attr.nullable)
+ }
s"""
- | if ($batch != null) {
- | $scanBatches();
- | } else if ($input.hasNext()) {
- | Object $value = $input.next();
- | if ($value instanceof $columnarBatchClz) {
- | $batch = ($columnarBatchClz)$value;
- | $scanBatches();
- | } else {
- | $scanRows((InternalRow) $value);
- | }
- | }
+ |if ($batch == null) {
+ | $nextBatch();
+ |}
+ |while ($batch != null) {
+ | int numRows = $batch.numRows();
+ | while ($idx < numRows) {
+ | int $rowidx = $idx++;
+ | ${consume(ctx, columnsBatchInput).trim}
+ | if (shouldStop()) return;
+ | }
+ | $batch = null;
+ | $nextBatch();
+ |}
+ |$scanTimeMetric.add($scanTimeTotalNs / (1000 * 1000));
+ |$scanTimeTotalNs = 0;
""".stripMargin
}
}
@@ -315,4 +319,38 @@ private[sql] object DataSourceScan {
// Metadata keys
val INPUT_PATHS = "InputPaths"
val PUSHED_FILTERS = "PushedFilters"
+
+ def create(
+ output: Seq[Attribute],
+ rdd: RDD[InternalRow],
+ relation: BaseRelation,
+ metadata: Map[String, String] = Map.empty): DataSourceScan = {
+ val outputPartitioning = {
+ val bucketSpec = relation match {
+ // TODO: this should be closer to bucket planning.
+ case r: HadoopFsRelation if r.sqlContext.conf.bucketingEnabled => r.bucketSpec
+ case _ => None
+ }
+
+ def toAttribute(colName: String): Attribute = output.find(_.name == colName).getOrElse {
+ throw new AnalysisException(s"bucket column $colName not found in existing columns " +
+ s"(${output.map(_.name).mkString(", ")})")
+ }
+
+ bucketSpec.map { spec =>
+ val numBuckets = spec.numBuckets
+ val bucketColumns = spec.bucketColumnNames.map(toAttribute)
+ HashPartitioning(bucketColumns, numBuckets)
+ }.getOrElse {
+ UnknownPartitioning(0)
+ }
+ }
+
+ relation match {
+ case r: HadoopFsRelation if r.fileFormat.supportBatch(r.sqlContext, relation.schema) =>
+ BatchedDataSourceScan(output, rdd, relation, outputPartitioning, metadata)
+ case _ =>
+ RowDataSourceScan(output, rdd, relation, outputPartitioning, metadata)
+ }
+ }
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
index 912b84abc1..f5e1e77263 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
@@ -21,6 +21,8 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{AnalysisException, SQLContext}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, ReturnAnswer}
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ReuseExchange}
/**
* The primary workflow for executing relational queries using Spark. Designed to allow easy
@@ -31,6 +33,9 @@ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, ReturnAnswer}
*/
class QueryExecution(val sqlContext: SQLContext, val logical: LogicalPlan) {
+ // TODO: Move the planner and optimizer into here from SessionState.
+ protected def planner = sqlContext.sessionState.planner
+
def assertAnalyzed(): Unit = try sqlContext.sessionState.analyzer.checkAnalysis(analyzed) catch {
case e: AnalysisException =>
val ae = new AnalysisException(e.message, e.line, e.startPosition, Some(analyzed))
@@ -49,16 +54,32 @@ class QueryExecution(val sqlContext: SQLContext, val logical: LogicalPlan) {
lazy val sparkPlan: SparkPlan = {
SQLContext.setActive(sqlContext)
- sqlContext.sessionState.planner.plan(ReturnAnswer(optimizedPlan)).next()
+ planner.plan(ReturnAnswer(optimizedPlan)).next()
}
// executedPlan should not be used to initialize any SparkPlan. It should be
// only used for execution.
- lazy val executedPlan: SparkPlan = sqlContext.sessionState.prepareForExecution.execute(sparkPlan)
+ lazy val executedPlan: SparkPlan = prepareForExecution(sparkPlan)
/** Internal version of the RDD. Avoids copies and has no schema */
lazy val toRdd: RDD[InternalRow] = executedPlan.execute()
+ /**
+ * Prepares a planned [[SparkPlan]] for execution by inserting shuffle operations and internal
+ * row format conversions as needed.
+ */
+ protected def prepareForExecution(plan: SparkPlan): SparkPlan = {
+ preparations.foldLeft(plan) { case (sp, rule) => rule.apply(sp) }
+ }
+
+ /** A sequence of rules that will be applied in order to the physical plan before execution. */
+ protected def preparations: Seq[Rule[SparkPlan]] = Seq(
+ python.ExtractPythonUDFs,
+ PlanSubqueries(sqlContext),
+ EnsureRequirements(sqlContext.conf),
+ CollapseCodegenStages(sqlContext.conf),
+ ReuseExchange(sqlContext.conf))
+
protected def stringOrError[A](f: => A): String =
try f.toString catch { case e: Throwable => e.toString }
@@ -83,4 +104,20 @@ class QueryExecution(val sqlContext: SQLContext, val logical: LogicalPlan) {
|${stringOrError(executedPlan)}
""".stripMargin.trim
}
+
+ /** A special namespace for commands that can be used to debug query execution. */
+ // scalastyle:off
+ object debug {
+ // scalastyle:on
+
+ /**
+ * Prints to stdout all the generated code found in this plan (i.e. the output of each
+ * WholeStageCodegen subtree).
+ */
+ def codegen(): Unit = {
+ // scalastyle:off println
+ println(org.apache.spark.sql.execution.debug.codegenString(executedPlan))
+ // scalastyle:on println
+ }
+ }
}
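
A hedged sketch of the new debug namespace; since Dataset exposes its QueryExecution, the generated code of each WholeStageCodegen subtree can be inspected like this:

  // Sketch: print the code generated for a trivial aggregation.
  val counted = sqlContext.range(0L, 100L).groupBy().count()
  counted.queryExecution.debug.codegen()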
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Queryable.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Queryable.scala
deleted file mode 100644
index 38263af0f7..0000000000
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Queryable.scala
+++ /dev/null
@@ -1,124 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.execution
-
-import scala.util.control.NonFatal
-
-import org.apache.commons.lang3.StringUtils
-
-import org.apache.spark.sql.SQLContext
-import org.apache.spark.sql.types.StructType
-
-/** A trait that holds shared code between DataFrames and Datasets. */
-private[sql] trait Queryable {
- def schema: StructType
- def queryExecution: QueryExecution
- def sqlContext: SQLContext
-
- override def toString: String = {
- try {
- val builder = new StringBuilder
- val fields = schema.take(2).map {
- case f => s"${f.name}: ${f.dataType.simpleString(2)}"
- }
- builder.append("[")
- builder.append(fields.mkString(", "))
- if (schema.length > 2) {
- if (schema.length - fields.size == 1) {
- builder.append(" ... 1 more field")
- } else {
- builder.append(" ... " + (schema.length - 2) + " more fields")
- }
- }
- builder.append("]").toString()
- } catch {
- case NonFatal(e) =>
- s"Invalid tree; ${e.getMessage}:\n$queryExecution"
- }
- }
-
- def printSchema(): Unit
-
- def explain(extended: Boolean): Unit
-
- def explain(): Unit
-
- private[sql] def showString(_numRows: Int, truncate: Boolean = true): String
-
- /**
- * Format the string representing rows for output
- * @param rows The rows to show
- * @param numRows Number of rows to show
- * @param hasMoreData Whether some rows are not shown due to the limit
- * @param truncate Whether truncate long strings and align cells right
- *
- */
- private[sql] def formatString (
- rows: Seq[Seq[String]],
- numRows: Int,
- hasMoreData : Boolean,
- truncate: Boolean = true): String = {
- val sb = new StringBuilder
- val numCols = schema.fieldNames.length
-
- // Initialise the width of each column to a minimum value of '3'
- val colWidths = Array.fill(numCols)(3)
-
- // Compute the width of each column
- for (row <- rows) {
- for ((cell, i) <- row.zipWithIndex) {
- colWidths(i) = math.max(colWidths(i), cell.length)
- }
- }
-
- // Create SeparateLine
- val sep: String = colWidths.map("-" * _).addString(sb, "+", "+", "+\n").toString()
-
- // column names
- rows.head.zipWithIndex.map { case (cell, i) =>
- if (truncate) {
- StringUtils.leftPad(cell, colWidths(i))
- } else {
- StringUtils.rightPad(cell, colWidths(i))
- }
- }.addString(sb, "|", "|", "|\n")
-
- sb.append(sep)
-
- // data
- rows.tail.map {
- _.zipWithIndex.map { case (cell, i) =>
- if (truncate) {
- StringUtils.leftPad(cell.toString, colWidths(i))
- } else {
- StringUtils.rightPad(cell.toString, colWidths(i))
- }
- }.addString(sb, "|", "|", "|\n")
- }
-
- sb.append(sep)
-
- // For Data that has more than "numRows" records
- if (hasMoreData) {
- val rowsString = if (numRows == 1) "row" else "rows"
- sb.append(s"only showing top $numRows $rowsString\n")
- }
-
- sb.toString()
- }
-}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
index 010ed7f500..4091f65aec 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
@@ -84,8 +84,8 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
private[sql] def metrics: Map[String, SQLMetric[_, _]] = Map.empty
/**
- * Reset all the metrics.
- */
+ * Reset all the metrics.
+ */
private[sql] def resetMetrics(): Unit = {
metrics.valuesIterator.foreach(_.reset())
}
@@ -249,20 +249,24 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
/**
* Decode the byte arrays back to UnsafeRows and put them into buffer.
*/
- private def decodeUnsafeRows(bytes: Array[Byte], buffer: ArrayBuffer[InternalRow]): Unit = {
+ private def decodeUnsafeRows(bytes: Array[Byte]): Iterator[InternalRow] = {
val nFields = schema.length
val codec = CompressionCodec.createCodec(SparkEnv.get.conf)
val bis = new ByteArrayInputStream(bytes)
val ins = new DataInputStream(codec.compressedInputStream(bis))
- var sizeOfNextRow = ins.readInt()
- while (sizeOfNextRow >= 0) {
- val bs = new Array[Byte](sizeOfNextRow)
- ins.readFully(bs)
- val row = new UnsafeRow(nFields)
- row.pointTo(bs, sizeOfNextRow)
- buffer += row
- sizeOfNextRow = ins.readInt()
+
+ new Iterator[InternalRow] {
+ private var sizeOfNextRow = ins.readInt()
+ override def hasNext: Boolean = sizeOfNextRow >= 0
+ override def next(): InternalRow = {
+ val bs = new Array[Byte](sizeOfNextRow)
+ ins.readFully(bs)
+ val row = new UnsafeRow(nFields)
+ row.pointTo(bs, sizeOfNextRow)
+ sizeOfNextRow = ins.readInt()
+ row
+ }
}
}
@@ -274,12 +278,21 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
val results = ArrayBuffer[InternalRow]()
byteArrayRdd.collect().foreach { bytes =>
- decodeUnsafeRows(bytes, results)
+ decodeUnsafeRows(bytes).foreach(results.+=)
}
results.toArray
}
/**
+ * Runs this query returning the result as an iterator of InternalRow.
+ *
+ * Note: this will trigger multiple jobs (one for each partition).
+ */
+ def executeToIterator(): Iterator[InternalRow] = {
+ getByteArrayRdd().toLocalIterator.flatMap(decodeUnsafeRows)
+ }
+
+ /**
* Runs this query returning the result as an array, using external Row format.
*/
def executeCollectPublic(): Array[Row] = {
@@ -325,7 +338,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
(it: Iterator[Array[Byte]]) => if (it.hasNext) it.next() else Array.empty, p)
res.foreach { r =>
- decodeUnsafeRows(r.asInstanceOf[Array[Byte]], buf)
+ decodeUnsafeRows(r.asInstanceOf[Array[Byte]]).foreach(buf.+=)
}
partsScanned += p.size
@@ -379,6 +392,13 @@ private[sql] trait LeafNode extends SparkPlan {
override def producedAttributes: AttributeSet = outputSet
}
+object UnaryNode {
+ def unapply(a: Any): Option[(SparkPlan, SparkPlan)] = a match {
+ case s: SparkPlan if s.children.size == 1 => Some((s, s.children.head))
+ case _ => None
+ }
+}
+
private[sql] trait UnaryNode extends SparkPlan {
def child: SparkPlan
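The decodeUnsafeRows rewrite above swaps eager buffering for a lazy iterator over length-prefixed frames, which is what lets executeToIterator stream partitions via toLocalIterator.flatMap. A minimal sketch of the same framing pattern on plain byte arrays (hypothetical helper, not the Spark API; it assumes the stream is terminated by a negative length) is:

import java.io.{ByteArrayInputStream, DataInputStream}

// Lazily decode records framed as <int length><payload bytes>, terminated by a
// negative length, without materialising them all in a buffer first.
def decodeFrames(bytes: Array[Byte]): Iterator[Array[Byte]] = {
  val ins = new DataInputStream(new ByteArrayInputStream(bytes))
  new Iterator[Array[Byte]] {
    private var sizeOfNext = ins.readInt()
    override def hasNext: Boolean = sizeOfNext >= 0
    override def next(): Array[Byte] = {
      val payload = new Array[Byte](sizeOfNext)
      ins.readFully(payload)
      sizeOfNext = ins.readInt() // read the next frame header before handing the record out
      payload
    }
  }
}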
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala
index 9da2c74c62..8d05ae470d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala
@@ -26,19 +26,19 @@ import org.apache.spark.sql.internal.SQLConf
class SparkPlanner(
val sparkContext: SparkContext,
val conf: SQLConf,
- val experimentalMethods: ExperimentalMethods)
+ val extraStrategies: Seq[Strategy])
extends SparkStrategies {
def numPartitions: Int = conf.numShufflePartitions
def strategies: Seq[Strategy] =
- experimentalMethods.extraStrategies ++ (
+ extraStrategies ++ (
FileSourceStrategy ::
DataSourceStrategy ::
DDLStrategy ::
SpecialLimits ::
Aggregation ::
- LeftSemiJoin ::
+ ExistenceJoin ::
EquiJoinSelection ::
InMemoryScans ::
BasicOperators ::
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkQl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkQl.scala
deleted file mode 100644
index b9542c7173..0000000000
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkQl.scala
+++ /dev/null
@@ -1,329 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.spark.sql.execution
-
-import org.apache.spark.sql.{AnalysisException, SaveMode}
-import org.apache.spark.sql.catalyst.TableIdentifier
-import org.apache.spark.sql.catalyst.parser._
-import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, OneRowRelation}
-import org.apache.spark.sql.execution.command._
-import org.apache.spark.sql.execution.datasources._
-import org.apache.spark.sql.types.StructType
-
-private[sql] class SparkQl(conf: ParserConf = SimpleParserConf()) extends CatalystQl(conf) {
- import ParserUtils._
-
- /** Check if a command should not be explained. */
- protected def isNoExplainCommand(command: String): Boolean = {
- "TOK_DESCTABLE" == command || "TOK_ALTERTABLE" == command
- }
-
- /**
- * For each node, extract properties in the form of a list
- * ['key_part1', 'key_part2', 'key_part3', 'value']
- * into a pair (key_part1.key_part2.key_part3, value).
- *
- * Example format:
- *
- * TOK_TABLEPROPERTY
- * :- 'k1'
- * +- 'v1'
- * TOK_TABLEPROPERTY
- * :- 'k2'
- * +- 'v2'
- * TOK_TABLEPROPERTY
- * :- 'k3'
- * +- 'v3'
- */
- private def extractProps(
- props: Seq[ASTNode],
- expectedNodeText: String): Seq[(String, String)] = {
- props.map {
- case Token(x, keysAndValue) if x == expectedNodeText =>
- val key = keysAndValue.init.map { x => unquoteString(x.text) }.mkString(".")
- val value = unquoteString(keysAndValue.last.text)
- (key, value)
- case p =>
- parseFailed(s"Expected property '$expectedNodeText' in command", p)
- }
- }
-
- protected override def nodeToPlan(node: ASTNode): LogicalPlan = {
- node match {
- case Token("TOK_SETCONFIG", Nil) =>
- val keyValueSeparatorIndex = node.remainder.indexOf('=')
- if (keyValueSeparatorIndex >= 0) {
- val key = node.remainder.substring(0, keyValueSeparatorIndex).trim
- val value = node.remainder.substring(keyValueSeparatorIndex + 1).trim
- SetCommand(Some(key -> Option(value)))
- } else if (node.remainder.nonEmpty) {
- SetCommand(Some(node.remainder -> None))
- } else {
- SetCommand(None)
- }
-
- // Just fake explain for any of the native commands.
- case Token("TOK_EXPLAIN", explainArgs) if isNoExplainCommand(explainArgs.head.text) =>
- ExplainCommand(OneRowRelation)
-
- case Token("TOK_EXPLAIN", explainArgs) if "TOK_CREATETABLE" == explainArgs.head.text =>
- val Some(crtTbl) :: _ :: extended :: Nil =
- getClauses(Seq("TOK_CREATETABLE", "FORMATTED", "EXTENDED"), explainArgs)
- ExplainCommand(nodeToPlan(crtTbl), extended = extended.isDefined)
-
- case Token("TOK_EXPLAIN", explainArgs) =>
- // Ignore FORMATTED if present.
- val Some(query) :: _ :: extended :: Nil =
- getClauses(Seq("TOK_QUERY", "FORMATTED", "EXTENDED"), explainArgs)
- ExplainCommand(nodeToPlan(query), extended = extended.isDefined)
-
- case Token("TOK_REFRESHTABLE", nameParts :: Nil) =>
- val tableIdent = extractTableIdent(nameParts)
- RefreshTable(tableIdent)
-
- // CREATE DATABASE [IF NOT EXISTS] database_name [COMMENT database_comment]
- // [LOCATION path] [WITH DBPROPERTIES (key1=val1, key2=val2, ...)];
- case Token("TOK_CREATEDATABASE", Token(databaseName, Nil) :: args) =>
- val Seq(ifNotExists, dbLocation, databaseComment, dbprops) = getClauses(Seq(
- "TOK_IFNOTEXISTS",
- "TOK_DATABASELOCATION",
- "TOK_DATABASECOMMENT",
- "TOK_DATABASEPROPERTIES"), args)
- val location = dbLocation.map {
- case Token("TOK_DATABASELOCATION", Token(loc, Nil) :: Nil) => unquoteString(loc)
- case _ => parseFailed("Invalid CREATE DATABASE command", node)
- }
- val comment = databaseComment.map {
- case Token("TOK_DATABASECOMMENT", Token(com, Nil) :: Nil) => unquoteString(com)
- case _ => parseFailed("Invalid CREATE DATABASE command", node)
- }
- val props = dbprops.toSeq.flatMap {
- case Token("TOK_DATABASEPROPERTIES", Token("TOK_DBPROPLIST", propList) :: Nil) =>
- // Example format:
- //
- // TOK_DATABASEPROPERTIES
- // +- TOK_DBPROPLIST
- // :- TOK_TABLEPROPERTY
- // : :- 'k1'
- // : +- 'v1'
- // :- TOK_TABLEPROPERTY
- // :- 'k2'
- // +- 'v2'
- extractProps(propList, "TOK_TABLEPROPERTY")
- case _ => parseFailed("Invalid CREATE DATABASE command", node)
- }.toMap
- CreateDatabase(databaseName, ifNotExists.isDefined, location, comment, props)(node.source)
-
- // DROP DATABASE [IF EXISTS] database_name [RESTRICT|CASCADE];
- case Token("TOK_DROPDATABASE", Token(dbName, Nil) :: otherArgs) =>
- // Example format:
- //
- // TOK_DROPDATABASE
- // :- database_name
- // :- TOK_IFEXISTS
- // +- TOK_RESTRICT/TOK_CASCADE
- val databaseName = unquoteString(dbName)
- // The default is RESTRICT
- val Seq(ifExists, _, cascade) = getClauses(Seq(
- "TOK_IFEXISTS", "TOK_RESTRICT", "TOK_CASCADE"), otherArgs)
- DropDatabase(databaseName, ifExists.isDefined, restrict = cascade.isEmpty)(node.source)
-
- // CREATE [TEMPORARY] FUNCTION [db_name.]function_name AS class_name
- // [USING JAR|FILE|ARCHIVE 'file_uri' [, JAR|FILE|ARCHIVE 'file_uri'] ];
- case Token("TOK_CREATEFUNCTION", args) =>
- // Example format:
- //
- // TOK_CREATEFUNCTION
- // :- db_name
- // :- func_name
- // :- alias
- // +- TOK_RESOURCE_LIST
- // :- TOK_RESOURCE_URI
- // : :- TOK_JAR
- // : +- '/path/to/jar'
- // +- TOK_RESOURCE_URI
- // :- TOK_FILE
- // +- 'path/to/file'
- val (funcNameArgs, otherArgs) = args.partition {
- case Token("TOK_RESOURCE_LIST", _) => false
- case Token("TOK_TEMPORARY", _) => false
- case Token(_, Nil) => true
- case _ => parseFailed("Invalid CREATE FUNCTION command", node)
- }
- // If database name is specified, there are 3 tokens, otherwise 2.
- val (funcName, alias) = funcNameArgs match {
- case Token(dbName, Nil) :: Token(fname, Nil) :: Token(aname, Nil) :: Nil =>
- (unquoteString(dbName) + "." + unquoteString(fname), unquoteString(aname))
- case Token(fname, Nil) :: Token(aname, Nil) :: Nil =>
- (unquoteString(fname), unquoteString(aname))
- case _ =>
- parseFailed("Invalid CREATE FUNCTION command", node)
- }
- // Extract other keywords, if they exist
- val Seq(rList, temp) = getClauses(Seq("TOK_RESOURCE_LIST", "TOK_TEMPORARY"), otherArgs)
- val resources: Seq[(String, String)] = rList.toSeq.flatMap {
- case Token("TOK_RESOURCE_LIST", resList) =>
- resList.map {
- case Token("TOK_RESOURCE_URI", rType :: Token(rPath, Nil) :: Nil) =>
- val resourceType = rType match {
- case Token("TOK_JAR", Nil) => "jar"
- case Token("TOK_FILE", Nil) => "file"
- case Token("TOK_ARCHIVE", Nil) => "archive"
- case Token(f, _) => parseFailed(s"Unexpected resource format '$f'", node)
- }
- (resourceType, unquoteString(rPath))
- case _ => parseFailed("Invalid CREATE FUNCTION command", node)
- }
- case _ => parseFailed("Invalid CREATE FUNCTION command", node)
- }
- CreateFunction(funcName, alias, resources, temp.isDefined)(node.source)
-
- case Token("TOK_ALTERTABLE", alterTableArgs) =>
- AlterTableCommandParser.parse(node)
-
- case Token("TOK_CREATETABLEUSING", createTableArgs) =>
- val Seq(
- temp,
- ifNotExists,
- Some(tabName),
- tableCols,
- Some(Token("TOK_TABLEPROVIDER", providerNameParts)),
- tableOpts,
- tableAs) = getClauses(Seq(
- "TEMPORARY",
- "TOK_IFNOTEXISTS",
- "TOK_TABNAME", "TOK_TABCOLLIST",
- "TOK_TABLEPROVIDER",
- "TOK_TABLEOPTIONS",
- "TOK_QUERY"), createTableArgs)
- val tableIdent: TableIdentifier = extractTableIdent(tabName)
- val columns = tableCols.map {
- case Token("TOK_TABCOLLIST", fields) => StructType(fields.map(nodeToStructField))
- case _ => parseFailed("Invalid CREATE TABLE command", node)
- }
- val provider = providerNameParts.map {
- case Token(name, Nil) => name
- case _ => parseFailed("Invalid CREATE TABLE command", node)
- }.mkString(".")
- val options = tableOpts.toSeq.flatMap {
- case Token("TOK_TABLEOPTIONS", opts) => extractProps(opts, "TOK_TABLEOPTION")
- case _ => parseFailed("Invalid CREATE TABLE command", node)
- }.toMap
- val asClause = tableAs.map(nodeToPlan)
-
- if (temp.isDefined && ifNotExists.isDefined) {
- throw new AnalysisException(
- "a CREATE TEMPORARY TABLE statement does not allow IF NOT EXISTS clause.")
- }
-
- if (asClause.isDefined) {
- if (columns.isDefined) {
- throw new AnalysisException(
- "a CREATE TABLE AS SELECT statement does not allow column definitions.")
- }
-
- val mode = if (ifNotExists.isDefined) {
- SaveMode.Ignore
- } else if (temp.isDefined) {
- SaveMode.Overwrite
- } else {
- SaveMode.ErrorIfExists
- }
-
- CreateTableUsingAsSelect(tableIdent,
- provider,
- temp.isDefined,
- Array.empty[String],
- bucketSpec = None,
- mode,
- options,
- asClause.get)
- } else {
- CreateTableUsing(
- tableIdent,
- columns,
- provider,
- temp.isDefined,
- options,
- ifNotExists.isDefined,
- managedIfNoPath = false)
- }
-
- case Token("TOK_SWITCHDATABASE", Token(database, Nil) :: Nil) =>
- SetDatabaseCommand(cleanIdentifier(database))
-
- case Token("TOK_DESCTABLE", describeArgs) =>
- // Reference: https://cwiki.apache.org/confluence/display/Hive/LanguageManual+DDL
- val Some(tableType) :: formatted :: extended :: pretty :: Nil =
- getClauses(Seq("TOK_TABTYPE", "FORMATTED", "EXTENDED", "PRETTY"), describeArgs)
- if (formatted.isDefined || pretty.isDefined) {
- // FORMATTED and PRETTY are not supported and this statement will be treated as
- // a Hive native command.
- nodeToDescribeFallback(node)
- } else {
- tableType match {
- case Token("TOK_TABTYPE", Token("TOK_TABNAME", nameParts) :: Nil) =>
- nameParts match {
- case Token(dbName, Nil) :: Token(tableName, Nil) :: Nil =>
- // It is describing a table with the format like "describe db.table".
- // TODO: Actually, a user may mean tableName.columnName. Need to resolve this
- // issue.
- val tableIdent = TableIdentifier(
- cleanIdentifier(tableName), Some(cleanIdentifier(dbName)))
- datasources.DescribeCommand(tableIdent, isExtended = extended.isDefined)
- case Token(dbName, Nil) :: Token(tableName, Nil) :: Token(colName, Nil) :: Nil =>
- // It is describing a column with the format like "describe db.table column".
- nodeToDescribeFallback(node)
- case tableName :: Nil =>
- // It is describing a table with the format like "describe table".
- datasources.DescribeCommand(
- TableIdentifier(cleanIdentifier(tableName.text)),
- isExtended = extended.isDefined)
- case _ =>
- nodeToDescribeFallback(node)
- }
- // All other cases.
- case _ =>
- nodeToDescribeFallback(node)
- }
- }
-
- case Token("TOK_CACHETABLE", Token(tableName, Nil) :: args) =>
- val Seq(lzy, selectAst) = getClauses(Seq("LAZY", "TOK_QUERY"), args)
- CacheTableCommand(tableName, selectAst.map(nodeToPlan), lzy.isDefined)
-
- case Token("TOK_UNCACHETABLE", Token(tableName, Nil) :: Nil) =>
- UncacheTableCommand(tableName)
-
- case Token("TOK_CLEARCACHE", Nil) =>
- ClearCacheCommand
-
- case Token("TOK_SHOWTABLES", args) =>
- val databaseName = args match {
- case Nil => None
- case Token("TOK_FROM", Token(dbName, Nil) :: Nil) :: Nil => Option(dbName)
- case _ => noParseRule("SHOW TABLES", node)
- }
- ShowTablesCommand(databaseName)
-
- case _ =>
- super.nodeToPlan(node)
- }
- }
-
- protected def nodeToDescribeFallback(node: ASTNode): LogicalPlan = noParseRule("Describe", node)
-}
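The extractProps helper removed above turns a ['key_part1', ..., 'value'] token list into a dotted key and a value; the core of that transformation, sketched on plain strings rather than ASTNodes (illustrative only, not the removed API):

// All elements but the last form the dotted key; the last element is the value.
def extractProp(keysAndValue: Seq[String]): (String, String) =
  (keysAndValue.init.mkString("."), keysAndValue.last)

// extractProp(Seq("k1", "part2", "v1")) returns ("k1.part2", "v1")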
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala
new file mode 100644
index 0000000000..8ed6ed21d0
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala
@@ -0,0 +1,792 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution
+
+import scala.collection.JavaConverters._
+
+import org.apache.spark.sql.{AnalysisException, SaveMode}
+import org.apache.spark.sql.catalyst.TableIdentifier
+import org.apache.spark.sql.catalyst.parser._
+import org.apache.spark.sql.catalyst.parser.SqlBaseParser._
+import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, OneRowRelation}
+import org.apache.spark.sql.execution.command.{DescribeCommand => _, _}
+import org.apache.spark.sql.execution.datasources._
+
+/**
+ * Concrete parser for Spark SQL statements.
+ */
+object SparkSqlParser extends AbstractSqlParser {
+ val astBuilder = new SparkSqlAstBuilder
+}
+
+/**
+ * Builder that converts an ANTLR ParseTree into a LogicalPlan/Expression/TableIdentifier.
+ */
+class SparkSqlAstBuilder extends AstBuilder {
+ import org.apache.spark.sql.catalyst.parser.ParserUtils._
+
+ /**
+ * Create a [[SetCommand]] logical plan.
+ *
+ * Note that everything after the SET keyword is assumed to be part of the
+ * key-value pair. The split between key and value is made by searching for the first `=`
+ * character in the raw string.
+ */
+ override def visitSetConfiguration(ctx: SetConfigurationContext): LogicalPlan = withOrigin(ctx) {
+ // Construct the command.
+ val raw = remainder(ctx.SET.getSymbol)
+ val keyValueSeparatorIndex = raw.indexOf('=')
+ if (keyValueSeparatorIndex >= 0) {
+ val key = raw.substring(0, keyValueSeparatorIndex).trim
+ val value = raw.substring(keyValueSeparatorIndex + 1).trim
+ SetCommand(Some(key -> Option(value)))
+ } else if (raw.nonEmpty) {
+ SetCommand(Some(raw.trim -> None))
+ } else {
+ SetCommand(None)
+ }
+ }
+
+ /**
+ * Create a [[SetDatabaseCommand]] logical plan.
+ */
+ override def visitUse(ctx: UseContext): LogicalPlan = withOrigin(ctx) {
+ SetDatabaseCommand(ctx.db.getText)
+ }
+
+ /**
+ * Create a [[ShowTablesCommand]] logical plan.
+ * Example SQL :
+ * {{{
+ * SHOW TABLES [(IN|FROM) database_name] [[LIKE] 'identifier_with_wildcards'];
+ * }}}
+ */
+ override def visitShowTables(ctx: ShowTablesContext): LogicalPlan = withOrigin(ctx) {
+ ShowTablesCommand(
+ Option(ctx.db).map(_.getText),
+ Option(ctx.pattern).map(string))
+ }
+
+ /**
+ * Create a [[ShowDatabasesCommand]] logical plan.
+ * Example SQL:
+ * {{{
+ * SHOW (DATABASES|SCHEMAS) [LIKE 'identifier_with_wildcards'];
+ * }}}
+ */
+ override def visitShowDatabases(ctx: ShowDatabasesContext): LogicalPlan = withOrigin(ctx) {
+ ShowDatabasesCommand(Option(ctx.pattern).map(string))
+ }
+
+ /**
+ * A command for users to list the properties for a table. If propertyKey is specified, the value
+ * for the propertyKey is returned. If propertyKey is not specified, all the keys and their
+ * corresponding values are returned.
+ * The syntax of using this command in SQL is:
+ * {{{
+ * SHOW TBLPROPERTIES table_name[('propertyKey')];
+ * }}}
+ */
+ override def visitShowTblProperties(
+ ctx: ShowTblPropertiesContext): LogicalPlan = withOrigin(ctx) {
+ ShowTablePropertiesCommand(
+ visitTableIdentifier(ctx.tableIdentifier),
+ Option(ctx.key).map(visitTablePropertyKey))
+ }
+
+ /**
+ * Create a [[RefreshTable]] logical plan.
+ */
+ override def visitRefreshTable(ctx: RefreshTableContext): LogicalPlan = withOrigin(ctx) {
+ RefreshTable(visitTableIdentifier(ctx.tableIdentifier))
+ }
+
+ /**
+ * Create a [[CacheTableCommand]] logical plan.
+ */
+ override def visitCacheTable(ctx: CacheTableContext): LogicalPlan = withOrigin(ctx) {
+ val query = Option(ctx.query).map(plan)
+ CacheTableCommand(ctx.identifier.getText, query, ctx.LAZY != null)
+ }
+
+ /**
+ * Create an [[UncacheTableCommand]] logical plan.
+ */
+ override def visitUncacheTable(ctx: UncacheTableContext): LogicalPlan = withOrigin(ctx) {
+ UncacheTableCommand(ctx.identifier.getText)
+ }
+
+ /**
+ * Create a [[ClearCacheCommand]] logical plan.
+ */
+ override def visitClearCache(ctx: ClearCacheContext): LogicalPlan = withOrigin(ctx) {
+ ClearCacheCommand
+ }
+
+ /**
+ * Create an [[ExplainCommand]] logical plan.
+ */
+ override def visitExplain(ctx: ExplainContext): LogicalPlan = withOrigin(ctx) {
+ val options = ctx.explainOption.asScala
+ if (options.exists(_.FORMATTED != null)) {
+ logWarning("Unsupported operation: EXPLAIN FORMATTED option")
+ }
+
+ // Create the explain comment.
+ val statement = plan(ctx.statement)
+ if (isExplainableStatement(statement)) {
+ ExplainCommand(statement, extended = options.exists(_.EXTENDED != null),
+ codegen = options.exists(_.CODEGEN != null))
+ } else {
+ ExplainCommand(OneRowRelation)
+ }
+ }
+
+ /**
+ * Determine if a plan should be explained at all.
+ */
+ protected def isExplainableStatement(plan: LogicalPlan): Boolean = plan match {
+ case _: datasources.DescribeCommand => false
+ case _ => true
+ }
+
+ /**
+ * Create a [[DescribeCommand]] logical plan.
+ */
+ override def visitDescribeTable(ctx: DescribeTableContext): LogicalPlan = withOrigin(ctx) {
+ // FORMATTED and columns are not supported. Return null and let the parser decide what to do
+ // with this (create an exception or pass it on to a different system).
+ if (ctx.describeColName != null || ctx.FORMATTED != null || ctx.partitionSpec != null) {
+ null
+ } else {
+ datasources.DescribeCommand(
+ visitTableIdentifier(ctx.tableIdentifier),
+ ctx.EXTENDED != null)
+ }
+ }
+
+ /**
+ * Type to keep track of a table header: (identifier, isTemporary, ifNotExists, isExternal).
+ */
+ type TableHeader = (TableIdentifier, Boolean, Boolean, Boolean)
+
+ /**
+ * Validate a create table statement and return the [[TableIdentifier]].
+ */
+ override def visitCreateTableHeader(
+ ctx: CreateTableHeaderContext): TableHeader = withOrigin(ctx) {
+ val temporary = ctx.TEMPORARY != null
+ val ifNotExists = ctx.EXISTS != null
+ assert(!temporary || !ifNotExists,
+ "a CREATE TEMPORARY TABLE statement does not allow IF NOT EXISTS clause.",
+ ctx)
+ (visitTableIdentifier(ctx.tableIdentifier), temporary, ifNotExists, ctx.EXTERNAL != null)
+ }
+
+ /**
+ * Create a [[CreateTableUsing]] or a [[CreateTableUsingAsSelect]] logical plan.
+ *
+ * TODO add bucketing and partitioning.
+ */
+ override def visitCreateTableUsing(ctx: CreateTableUsingContext): LogicalPlan = withOrigin(ctx) {
+ val (table, temp, ifNotExists, external) = visitCreateTableHeader(ctx.createTableHeader)
+ if (external) {
+ throw new ParseException("Unsupported operation: EXTERNAL option", ctx)
+ }
+ val options = Option(ctx.tablePropertyList).map(visitTablePropertyList).getOrElse(Map.empty)
+ val provider = ctx.tableProvider.qualifiedName.getText
+
+ if (ctx.query != null) {
+ // Get the backing query.
+ val query = plan(ctx.query)
+
+ // Determine the storage mode.
+ val mode = if (ifNotExists) {
+ SaveMode.Ignore
+ } else if (temp) {
+ SaveMode.Overwrite
+ } else {
+ SaveMode.ErrorIfExists
+ }
+ CreateTableUsingAsSelect(table, provider, temp, Array.empty, None, mode, options, query)
+ } else {
+ val struct = Option(ctx.colTypeList).map(createStructType)
+ CreateTableUsing(table, struct, provider, temp, options, ifNotExists, managedIfNoPath = false)
+ }
+ }
+
+ /**
+ * Convert a table property list into a key-value map.
+ */
+ override def visitTablePropertyList(
+ ctx: TablePropertyListContext): Map[String, String] = withOrigin(ctx) {
+ ctx.tableProperty.asScala.map { property =>
+ val key = visitTablePropertyKey(property.key)
+ val value = Option(property.value).map(string).orNull
+ key -> value
+ }.toMap
+ }
+
+ /**
+ * A table property key can either be a String or a collection of dot separated elements. This
+ * function extracts the property key based on whether it is a string literal or a table property
+ * identifier.
+ */
+ override def visitTablePropertyKey(key: TablePropertyKeyContext): String = {
+ if (key.STRING != null) {
+ string(key.STRING)
+ } else {
+ key.getText
+ }
+ }
+
+ /**
+ * Create a [[CreateDatabase]] command.
+ *
+ * For example:
+ * {{{
+ * CREATE DATABASE [IF NOT EXISTS] database_name [COMMENT database_comment]
+ * [LOCATION path] [WITH DBPROPERTIES (key1=val1, key2=val2, ...)]
+ * }}}
+ */
+ override def visitCreateDatabase(ctx: CreateDatabaseContext): LogicalPlan = withOrigin(ctx) {
+ CreateDatabase(
+ ctx.identifier.getText,
+ ctx.EXISTS != null,
+ Option(ctx.locationSpec).map(visitLocationSpec),
+ Option(ctx.comment).map(string),
+ Option(ctx.tablePropertyList).map(visitTablePropertyList).getOrElse(Map.empty))
+ }
+
+ /**
+ * Create an [[AlterDatabaseProperties]] command.
+ *
+ * For example:
+ * {{{
+ * ALTER (DATABASE|SCHEMA) database SET DBPROPERTIES (property_name=property_value, ...);
+ * }}}
+ */
+ override def visitSetDatabaseProperties(
+ ctx: SetDatabasePropertiesContext): LogicalPlan = withOrigin(ctx) {
+ AlterDatabaseProperties(
+ ctx.identifier.getText,
+ visitTablePropertyList(ctx.tablePropertyList))
+ }
+
+ /**
+ * Create a [[DropDatabase]] command.
+ *
+ * For example:
+ * {{{
+ * DROP (DATABASE|SCHEMA) [IF EXISTS] database [RESTRICT|CASCADE];
+ * }}}
+ */
+ override def visitDropDatabase(ctx: DropDatabaseContext): LogicalPlan = withOrigin(ctx) {
+ DropDatabase(ctx.identifier.getText, ctx.EXISTS != null, ctx.CASCADE != null)
+ }
+
+ /**
+ * Create a [[DescribeDatabase]] command.
+ *
+ * For example:
+ * {{{
+ * DESCRIBE DATABASE [EXTENDED] database;
+ * }}}
+ */
+ override def visitDescribeDatabase(ctx: DescribeDatabaseContext): LogicalPlan = withOrigin(ctx) {
+ DescribeDatabase(ctx.identifier.getText, ctx.EXTENDED != null)
+ }
+
+ /**
+ * Create a [[CreateFunction]] command.
+ *
+ * For example:
+ * {{{
+ * CREATE [TEMPORARY] FUNCTION [db_name.]function_name AS class_name
+ * [USING JAR|FILE|ARCHIVE 'file_uri' [, JAR|FILE|ARCHIVE 'file_uri']];
+ * }}}
+ */
+ override def visitCreateFunction(ctx: CreateFunctionContext): LogicalPlan = withOrigin(ctx) {
+ val resources = ctx.resource.asScala.map { resource =>
+ val resourceType = resource.identifier.getText.toLowerCase
+ resourceType match {
+ case "jar" | "file" | "archive" =>
+ resourceType -> string(resource.STRING)
+ case other =>
+ throw new ParseException(s"Resource Type '$resourceType' is not supported.", ctx)
+ }
+ }
+
+ // Extract database, name & alias.
+ val (database, function) = visitFunctionName(ctx.qualifiedName)
+ CreateFunction(
+ database,
+ function,
+ string(ctx.className),
+ resources,
+ ctx.TEMPORARY != null)
+ }
+
+ /**
+ * Create a [[DropFunction]] command.
+ *
+ * For example:
+ * {{{
+ * DROP [TEMPORARY] FUNCTION [IF EXISTS] function;
+ * }}}
+ */
+ override def visitDropFunction(ctx: DropFunctionContext): LogicalPlan = withOrigin(ctx) {
+ val (database, function) = visitFunctionName(ctx.qualifiedName)
+ DropFunction(database, function, ctx.EXISTS != null, ctx.TEMPORARY != null)
+ }
+
+ /**
+ * Create a function database (optional) and name pair.
+ */
+ private def visitFunctionName(ctx: QualifiedNameContext): (Option[String], String) = {
+ ctx.identifier().asScala.map(_.getText) match {
+ case Seq(db, fn) => (Option(db), fn)
+ case Seq(fn) => (None, fn)
+ case other => throw new ParseException(s"Unsupported function name '${ctx.getText}'", ctx)
+ }
+ }
+
+ /**
+ * Create a [[DropTable]] command.
+ */
+ override def visitDropTable(ctx: DropTableContext): LogicalPlan = withOrigin(ctx) {
+ if (ctx.PURGE != null) {
+ throw new ParseException("Unsupported operation: PURGE option", ctx)
+ }
+ if (ctx.REPLICATION != null) {
+ throw new ParseException("Unsupported operation: REPLICATION clause", ctx)
+ }
+ DropTable(
+ visitTableIdentifier(ctx.tableIdentifier),
+ ctx.EXISTS != null,
+ ctx.VIEW != null)
+ }
+
+ /**
+ * Create a [[AlterTableRename]] command.
+ *
+ * For example:
+ * {{{
+ * ALTER TABLE table1 RENAME TO table2;
+ * ALTER VIEW view1 RENAME TO view2;
+ * }}}
+ */
+ override def visitRenameTable(ctx: RenameTableContext): LogicalPlan = withOrigin(ctx) {
+ AlterTableRename(
+ visitTableIdentifier(ctx.from),
+ visitTableIdentifier(ctx.to),
+ ctx.VIEW != null)
+ }
+
+ /**
+ * Create an [[AlterTableSetProperties]] command.
+ *
+ * For example:
+ * {{{
+ * ALTER TABLE table SET TBLPROPERTIES ('comment' = new_comment);
+ * ALTER VIEW view SET TBLPROPERTIES ('comment' = new_comment);
+ * }}}
+ */
+ override def visitSetTableProperties(
+ ctx: SetTablePropertiesContext): LogicalPlan = withOrigin(ctx) {
+ AlterTableSetProperties(
+ visitTableIdentifier(ctx.tableIdentifier),
+ visitTablePropertyList(ctx.tablePropertyList),
+ ctx.VIEW != null)
+ }
+
+ /**
+ * Create an [[AlterTableUnsetProperties]] command.
+ *
+ * For example:
+ * {{{
+ * ALTER TABLE table UNSET TBLPROPERTIES [IF EXISTS] ('comment', 'key');
+ * ALTER VIEW view UNSET TBLPROPERTIES [IF EXISTS] ('comment', 'key');
+ * }}}
+ */
+ override def visitUnsetTableProperties(
+ ctx: UnsetTablePropertiesContext): LogicalPlan = withOrigin(ctx) {
+ AlterTableUnsetProperties(
+ visitTableIdentifier(ctx.tableIdentifier),
+ visitTablePropertyList(ctx.tablePropertyList).keys.toSeq,
+ ctx.EXISTS != null,
+ ctx.VIEW != null)
+ }
+
+ /**
+ * Create an [[AlterTableSerDeProperties]] command.
+ *
+ * For example:
+ * {{{
+ * ALTER TABLE table [PARTITION spec] SET SERDE serde_name [WITH SERDEPROPERTIES props];
+ * ALTER TABLE table [PARTITION spec] SET SERDEPROPERTIES serde_properties;
+ * }}}
+ */
+ override def visitSetTableSerDe(ctx: SetTableSerDeContext): LogicalPlan = withOrigin(ctx) {
+ AlterTableSerDeProperties(
+ visitTableIdentifier(ctx.tableIdentifier),
+ Option(ctx.STRING).map(string),
+ Option(ctx.tablePropertyList).map(visitTablePropertyList),
+ // TODO a partition spec is allowed to have optional values. This is currently violated.
+ Option(ctx.partitionSpec).map(visitNonOptionalPartitionSpec))
+ }
+
+ // TODO: don't even bother parsing alter table commands related to bucketing and skewing
+
+ override def visitBucketTable(ctx: BucketTableContext): LogicalPlan = withOrigin(ctx) {
+ throw new AnalysisException(
+ "Operation not allowed: ALTER TABLE ... CLUSTERED BY ... INTO N BUCKETS")
+ }
+
+ override def visitUnclusterTable(ctx: UnclusterTableContext): LogicalPlan = withOrigin(ctx) {
+ throw new AnalysisException("Operation not allowed: ALTER TABLE ... NOT CLUSTERED")
+ }
+
+ override def visitUnsortTable(ctx: UnsortTableContext): LogicalPlan = withOrigin(ctx) {
+ throw new AnalysisException("Operation not allowed: ALTER TABLE ... NOT SORTED")
+ }
+
+ override def visitSkewTable(ctx: SkewTableContext): LogicalPlan = withOrigin(ctx) {
+ throw new AnalysisException("Operation not allowed: ALTER TABLE ... SKEWED BY ...")
+ }
+
+ override def visitUnskewTable(ctx: UnskewTableContext): LogicalPlan = withOrigin(ctx) {
+ throw new AnalysisException("Operation not allowed: ALTER TABLE ... NOT SKEWED")
+ }
+
+ override def visitUnstoreTable(ctx: UnstoreTableContext): LogicalPlan = withOrigin(ctx) {
+ throw new AnalysisException(
+ "Operation not allowed: ALTER TABLE ... NOT STORED AS DIRECTORIES")
+ }
+
+ override def visitSetTableSkewLocations(
+ ctx: SetTableSkewLocationsContext): LogicalPlan = withOrigin(ctx) {
+ throw new AnalysisException(
+ "Operation not allowed: ALTER TABLE ... SET SKEWED LOCATION ...")
+ }
+
+ /**
+ * Create an [[AlterTableAddPartition]] command.
+ *
+ * For example:
+ * {{{
+ * ALTER TABLE table ADD [IF NOT EXISTS] PARTITION spec [LOCATION 'loc1']
+ * ALTER VIEW view ADD [IF NOT EXISTS] PARTITION spec
+ * }}}
+ *
+ * ALTER VIEW ... ADD PARTITION ... is not supported because the concept of partitioning
+ * is associated with physical tables
+ */
+ override def visitAddTablePartition(
+ ctx: AddTablePartitionContext): LogicalPlan = withOrigin(ctx) {
+ if (ctx.VIEW != null) {
+ throw new AnalysisException(s"Operation not allowed: partitioned views")
+ }
+ // Create partition spec to location mapping.
+ val specsAndLocs = if (ctx.partitionSpec.isEmpty) {
+ ctx.partitionSpecLocation.asScala.map {
+ splCtx =>
+ val spec = visitNonOptionalPartitionSpec(splCtx.partitionSpec)
+ val location = Option(splCtx.locationSpec).map(visitLocationSpec)
+ spec -> location
+ }
+ } else {
+ // Alter View: the location clauses are not allowed.
+ ctx.partitionSpec.asScala.map(visitNonOptionalPartitionSpec(_) -> None)
+ }
+ AlterTableAddPartition(
+ visitTableIdentifier(ctx.tableIdentifier),
+ specsAndLocs,
+ ctx.EXISTS != null)
+ }
+
+ /**
+ * Create an [[AlterTableExchangePartition]] command.
+ *
+ * For example:
+ * {{{
+ * ALTER TABLE table1 EXCHANGE PARTITION spec WITH TABLE table2;
+ * }}}
+ */
+ override def visitExchangeTablePartition(
+ ctx: ExchangeTablePartitionContext): LogicalPlan = withOrigin(ctx) {
+ throw new AnalysisException(
+ "Operation not allowed: ALTER TABLE ... EXCHANGE PARTITION ...")
+ }
+
+ /**
+ * Create an [[AlterTableRenamePartition]] command
+ *
+ * For example:
+ * {{{
+ * ALTER TABLE table PARTITION spec1 RENAME TO PARTITION spec2;
+ * }}}
+ */
+ override def visitRenameTablePartition(
+ ctx: RenameTablePartitionContext): LogicalPlan = withOrigin(ctx) {
+ AlterTableRenamePartition(
+ visitTableIdentifier(ctx.tableIdentifier),
+ visitNonOptionalPartitionSpec(ctx.from),
+ visitNonOptionalPartitionSpec(ctx.to))
+ }
+
+ /**
+ * Create an [[AlterTableDropPartition]] command
+ *
+ * For example:
+ * {{{
+ * ALTER TABLE table DROP [IF EXISTS] PARTITION spec1[, PARTITION spec2, ...] [PURGE];
+ * ALTER VIEW view DROP [IF EXISTS] PARTITION spec1[, PARTITION spec2, ...];
+ * }}}
+ *
+ * ALTER VIEW ... DROP PARTITION ... is not supported because the concept of partitioning
+ * is associated with physical tables
+ */
+ override def visitDropTablePartitions(
+ ctx: DropTablePartitionsContext): LogicalPlan = withOrigin(ctx) {
+ if (ctx.VIEW != null) {
+ throw new AnalysisException(s"Operation not allowed: partitioned views")
+ }
+ if (ctx.PURGE != null) {
+ throw new AnalysisException(s"Operation not allowed: PURGE")
+ }
+ AlterTableDropPartition(
+ visitTableIdentifier(ctx.tableIdentifier),
+ ctx.partitionSpec.asScala.map(visitNonOptionalPartitionSpec),
+ ctx.EXISTS != null)
+ }
+
+ /**
+ * Create an [[AlterTableArchivePartition]] command
+ *
+ * For example:
+ * {{{
+ * ALTER TABLE table ARCHIVE PARTITION spec;
+ * }}}
+ */
+ override def visitArchiveTablePartition(
+ ctx: ArchiveTablePartitionContext): LogicalPlan = withOrigin(ctx) {
+ throw new AnalysisException(
+ "Operation not allowed: ALTER TABLE ... ARCHIVE PARTITION ...")
+ }
+
+ /**
+ * Create an [[AlterTableUnarchivePartition]] command
+ *
+ * For example:
+ * {{{
+ * ALTER TABLE table UNARCHIVE PARTITION spec;
+ * }}}
+ */
+ override def visitUnarchiveTablePartition(
+ ctx: UnarchiveTablePartitionContext): LogicalPlan = withOrigin(ctx) {
+ throw new AnalysisException(
+ "Operation not allowed: ALTER TABLE ... UNARCHIVE PARTITION ...")
+ }
+
+ /**
+ * Create an [[AlterTableSetFileFormat]] command
+ *
+ * For example:
+ * {{{
+ * ALTER TABLE table [PARTITION spec] SET FILEFORMAT file_format;
+ * }}}
+ */
+ override def visitSetTableFileFormat(
+ ctx: SetTableFileFormatContext): LogicalPlan = withOrigin(ctx) {
+ // AlterTableSetFileFormat currently takes both a GenericFileFormat and a
+ // TableFileFormatContext. This is a bit weird because it should only take one. It also should
+ // use a CatalogFileFormat instead of either a String or a Sequence of Strings. We will address
+ // this in a follow-up PR.
+ val (fileFormat, genericFormat) = ctx.fileFormat match {
+ case s: GenericFileFormatContext =>
+ (Seq.empty[String], Option(s.identifier.getText))
+ case s: TableFileFormatContext =>
+ val elements = Seq(s.inFmt, s.outFmt) ++ Option(s.serdeCls).toSeq
+ (elements.map(string), None)
+ }
+ AlterTableSetFileFormat(
+ visitTableIdentifier(ctx.tableIdentifier),
+ Option(ctx.partitionSpec).map(visitNonOptionalPartitionSpec),
+ fileFormat,
+ genericFormat)(
+ command(ctx))
+ }
+
+ /**
+ * Create an [[AlterTableSetLocation]] command
+ *
+ * For example:
+ * {{{
+ * ALTER TABLE table [PARTITION spec] SET LOCATION "loc";
+ * }}}
+ */
+ override def visitSetTableLocation(ctx: SetTableLocationContext): LogicalPlan = withOrigin(ctx) {
+ AlterTableSetLocation(
+ visitTableIdentifier(ctx.tableIdentifier),
+ Option(ctx.partitionSpec).map(visitNonOptionalPartitionSpec),
+ visitLocationSpec(ctx.locationSpec))
+ }
+
+ /**
+ * Create an [[AlterTableTouch]] command
+ *
+ * For example:
+ * {{{
+ * ALTER TABLE table TOUCH [PARTITION spec];
+ * }}}
+ */
+ override def visitTouchTable(ctx: TouchTableContext): LogicalPlan = withOrigin(ctx) {
+ throw new AnalysisException("Operation not allowed: ALTER TABLE ... TOUCH ...")
+ }
+
+ /**
+ * Create an [[AlterTableCompact]] command
+ *
+ * For example:
+ * {{{
+ * ALTER TABLE table [PARTITION spec] COMPACT 'compaction_type';
+ * }}}
+ */
+ override def visitCompactTable(ctx: CompactTableContext): LogicalPlan = withOrigin(ctx) {
+ throw new AnalysisException("Operation not allowed: ALTER TABLE ... COMPACT ...")
+ }
+
+ /**
+ * Create an [[AlterTableMerge]] command
+ *
+ * For example:
+ * {{{
+ * ALTER TABLE table [PARTITION spec] CONCATENATE;
+ * }}}
+ */
+ override def visitConcatenateTable(ctx: ConcatenateTableContext): LogicalPlan = withOrigin(ctx) {
+ throw new AnalysisException("Operation not allowed: ALTER TABLE ... CONCATENATE")
+ }
+
+ /**
+ * Create an [[AlterTableChangeCol]] command
+ *
+ * For example:
+ * {{{
+ * ALTER TABLE tableIdentifier [PARTITION spec]
+ * CHANGE [COLUMN] col_old_name col_new_name column_type [COMMENT col_comment]
+ * [FIRST|AFTER column_name] [CASCADE|RESTRICT];
+ * }}}
+ */
+ override def visitChangeColumn(ctx: ChangeColumnContext): LogicalPlan = withOrigin(ctx) {
+ val col = visitColType(ctx.colType())
+ val comment = if (col.metadata.contains("comment")) {
+ Option(col.metadata.getString("comment"))
+ } else {
+ None
+ }
+
+ AlterTableChangeCol(
+ visitTableIdentifier(ctx.tableIdentifier),
+ Option(ctx.partitionSpec).map(visitNonOptionalPartitionSpec),
+ ctx.oldName.getText,
+ // We could also pass in a struct field - seems easier.
+ col.name,
+ col.dataType,
+ comment,
+ Option(ctx.after).map(_.getText),
+ // Note that Restrict and Cascade are mutually exclusive.
+ ctx.RESTRICT != null,
+ ctx.CASCADE != null)(
+ command(ctx))
+ }
+
+ /**
+ * Create an [[AlterTableAddCol]] command
+ *
+ * For example:
+ * {{{
+ * ALTER TABLE tableIdentifier [PARTITION spec]
+ * ADD COLUMNS (name type [COMMENT comment], ...) [CASCADE|RESTRICT]
+ * }}}
+ */
+ override def visitAddColumns(ctx: AddColumnsContext): LogicalPlan = withOrigin(ctx) {
+ AlterTableAddCol(
+ visitTableIdentifier(ctx.tableIdentifier),
+ Option(ctx.partitionSpec).map(visitNonOptionalPartitionSpec),
+ createStructType(ctx.colTypeList),
+ // Note that Restrict and Cascade are mutually exclusive.
+ ctx.RESTRICT != null,
+ ctx.CASCADE != null)(
+ command(ctx))
+ }
+
+ /**
+ * Create an [[AlterTableReplaceCol]] command
+ *
+ * For example:
+ * {{{
+ * ALTER TABLE tableIdentifier [PARTITION spec]
+ * REPLACE COLUMNS (name type [COMMENT comment], ...) [CASCADE|RESTRICT]
+ * }}}
+ */
+ override def visitReplaceColumns(ctx: ReplaceColumnsContext): LogicalPlan = withOrigin(ctx) {
+ AlterTableReplaceCol(
+ visitTableIdentifier(ctx.tableIdentifier),
+ Option(ctx.partitionSpec).map(visitNonOptionalPartitionSpec),
+ createStructType(ctx.colTypeList),
+ // Note that Restrict and Cascade are mutually exclusive.
+ ctx.RESTRICT != null,
+ ctx.CASCADE != null)(
+ command(ctx))
+ }
+
+ /**
+ * Create location string.
+ */
+ override def visitLocationSpec(ctx: LocationSpecContext): String = withOrigin(ctx) {
+ string(ctx.STRING)
+ }
+
+ /**
+ * Create a [[BucketSpec]].
+ */
+ override def visitBucketSpec(ctx: BucketSpecContext): BucketSpec = withOrigin(ctx) {
+ BucketSpec(
+ ctx.INTEGER_VALUE.getText.toInt,
+ visitIdentifierList(ctx.identifierList),
+ Option(ctx.orderedIdentifierList).toSeq
+ .flatMap(_.orderedIdentifier.asScala)
+ .map(_.identifier.getText))
+ }
+
+ /**
+ * Convert a nested constants list into a sequence of string sequences.
+ */
+ override def visitNestedConstantList(
+ ctx: NestedConstantListContext): Seq[Seq[String]] = withOrigin(ctx) {
+ ctx.constantList.asScala.map(visitConstantList)
+ }
+
+ /**
+ * Convert a constants list into a String sequence.
+ */
+ override def visitConstantList(ctx: ConstantListContext): Seq[String] = withOrigin(ctx) {
+ ctx.constant.asScala.map(visitStringConstant)
+ }
+}
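visitSetConfiguration above treats everything after SET as one raw string and splits it on the first '='; the same splitting rule in isolation (hypothetical helper, not part of the parser itself):

// Returns None for a bare SET, Some(key -> None) for "SET key", and
// Some(key -> Some(value)) when the raw text contains an '='.
def splitSetCommand(raw: String): Option[(String, Option[String])] = {
  val idx = raw.indexOf('=')
  if (idx >= 0) {
    Some(raw.substring(0, idx).trim -> Some(raw.substring(idx + 1).trim))
  } else if (raw.trim.nonEmpty) {
    Some(raw.trim -> None)
  } else {
    None
  }
}

// splitSetCommand("spark.sql.shuffle.partitions = 10")
//   == Some(("spark.sql.shuffle.partitions", Some("10")))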
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 7841ff01f9..c15aaed365 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -19,8 +19,8 @@ package org.apache.spark.sql.execution
import org.apache.spark.sql.Strategy
import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.planning._
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical.{BroadcastHint, LogicalPlan}
@@ -31,6 +31,7 @@ import org.apache.spark.sql.execution.command.{DescribeCommand => RunnableDescri
import org.apache.spark.sql.execution.datasources.{DescribeCommand => LogicalDescribeCommand, _}
import org.apache.spark.sql.execution.exchange.ShuffleExchange
import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight}
+import org.apache.spark.sql.execution.streaming.MemoryPlan
import org.apache.spark.sql.internal.SQLConf
private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
@@ -61,16 +62,17 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
}
}
- object LeftSemiJoin extends Strategy with PredicateHelper {
+ object ExistenceJoin extends Strategy with PredicateHelper {
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
case ExtractEquiJoinKeys(
- LeftSemi, leftKeys, rightKeys, condition, left, CanBroadcast(right)) =>
+ LeftExistence(jt), leftKeys, rightKeys, condition, left, CanBroadcast(right)) =>
Seq(joins.BroadcastHashJoin(
- leftKeys, rightKeys, LeftSemi, BuildRight, condition, planLater(left), planLater(right)))
+ leftKeys, rightKeys, jt, BuildRight, condition, planLater(left), planLater(right)))
// Find left semi joins where at least some predicates can be evaluated by matching join keys
- case ExtractEquiJoinKeys(LeftSemi, leftKeys, rightKeys, condition, left, right) =>
+ case ExtractEquiJoinKeys(
+ LeftExistence(jt), leftKeys, rightKeys, condition, left, right) =>
Seq(joins.ShuffledHashJoin(
- leftKeys, rightKeys, LeftSemi, BuildRight, condition, planLater(left), planLater(right)))
+ leftKeys, rightKeys, jt, BuildRight, condition, planLater(left), planLater(right)))
case _ => Nil
}
}
@@ -109,7 +111,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
/**
* Matches a plan whose single partition should be small enough to build a hash table.
*
- * Note: this assume that the number of partition is fixed, requires addtional work if it's
+ * Note: this assumes that the number of partitions is fixed; it requires additional work if it's
* dynamic.
*/
def canBuildHashMap(plan: LogicalPlan): Boolean = {
@@ -204,28 +206,32 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
}
/**
+ * Used to plan aggregation queries that are computed incrementally as part of a
+ * [[org.apache.spark.sql.ContinuousQuery]]. Currently this rule is injected into the planner
+ * on-demand, only when planning in a [[org.apache.spark.sql.execution.streaming.StreamExecution]]
+ */
+ object StatefulAggregationStrategy extends Strategy {
+ override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
+ case PhysicalAggregation(
+ namedGroupingExpressions, aggregateExpressions, rewrittenResultExpressions, child) =>
+
+ aggregate.Utils.planStreamingAggregation(
+ namedGroupingExpressions,
+ aggregateExpressions,
+ rewrittenResultExpressions,
+ planLater(child))
+
+ case _ => Nil
+ }
+ }
+
+ /**
* Used to plan the aggregate operator for expressions based on the AggregateFunction2 interface.
*/
object Aggregation extends Strategy {
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
- case logical.Aggregate(groupingExpressions, resultExpressions, child) =>
- // A single aggregate expression might appear multiple times in resultExpressions.
- // In order to avoid evaluating an individual aggregate function multiple times, we'll
- // build a set of the distinct aggregate expressions and build a function which can
- // be used to re-write expressions so that they reference the single copy of the
- // aggregate function which actually gets computed.
- val aggregateExpressions = resultExpressions.flatMap { expr =>
- expr.collect {
- case agg: AggregateExpression => agg
- }
- }.distinct
- // For those distinct aggregate expressions, we create a map from the
- // aggregate function to the corresponding attribute of the function.
- val aggregateFunctionToAttribute = aggregateExpressions.map { agg =>
- val aggregateFunction = agg.aggregateFunction
- val attribute = Alias(aggregateFunction, aggregateFunction.toString)().toAttribute
- (aggregateFunction, agg.isDistinct) -> attribute
- }.toMap
+ case PhysicalAggregation(
+ groupingExpressions, aggregateExpressions, resultExpressions, child) =>
val (functionsWithDistinct, functionsWithoutDistinct) =
aggregateExpressions.partition(_.isDistinct)
@@ -233,41 +239,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
// This is a sanity check. We should not reach here when we have multiple distinct
// column sets. Our MultipleDistinctRewriter should take care this case.
sys.error("You hit a query analyzer bug. Please report your query to " +
- "Spark user mailing list.")
- }
-
- val namedGroupingExpressions = groupingExpressions.map {
- case ne: NamedExpression => ne -> ne
- // If the expression is not a NamedExpressions, we add an alias.
- // So, when we generate the result of the operator, the Aggregate Operator
- // can directly get the Seq of attributes representing the grouping expressions.
- case other =>
- val withAlias = Alias(other, other.toString)()
- other -> withAlias
- }
- val groupExpressionMap = namedGroupingExpressions.toMap
-
- // The original `resultExpressions` are a set of expressions which may reference
- // aggregate expressions, grouping column values, and constants. When aggregate operator
- // emits output rows, we will use `resultExpressions` to generate an output projection
- // which takes the grouping columns and final aggregate result buffer as input.
- // Thus, we must re-write the result expressions so that their attributes match up with
- // the attributes of the final result projection's input row:
- val rewrittenResultExpressions = resultExpressions.map { expr =>
- expr.transformDown {
- case AggregateExpression(aggregateFunction, _, isDistinct) =>
- // The final aggregation buffer's attributes will be `finalAggregationAttributes`,
- // so replace each aggregate expression by its corresponding attribute in the set:
- aggregateFunctionToAttribute(aggregateFunction, isDistinct)
- case expression =>
- // Since we're using `namedGroupingAttributes` to extract the grouping key
- // columns, we need to replace grouping key expressions with their corresponding
- // attributes. We do not rely on the equality check at here since attributes may
- // differ cosmetically. Instead, we use semanticEquals.
- groupExpressionMap.collectFirst {
- case (expr, ne) if expr semanticEquals expression => ne.toAttribute
- }.getOrElse(expression)
- }.asInstanceOf[NamedExpression]
+ "Spark user mailing list.")
}
val aggregateOperator =
@@ -277,26 +249,23 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
"aggregate functions which don't support partial aggregation.")
} else {
aggregate.Utils.planAggregateWithoutPartial(
- namedGroupingExpressions.map(_._2),
+ groupingExpressions,
aggregateExpressions,
- aggregateFunctionToAttribute,
- rewrittenResultExpressions,
+ resultExpressions,
planLater(child))
}
} else if (functionsWithDistinct.isEmpty) {
aggregate.Utils.planAggregateWithoutDistinct(
- namedGroupingExpressions.map(_._2),
+ groupingExpressions,
aggregateExpressions,
- aggregateFunctionToAttribute,
- rewrittenResultExpressions,
+ resultExpressions,
planLater(child))
} else {
aggregate.Utils.planAggregateWithOneDistinct(
- namedGroupingExpressions.map(_._2),
+ groupingExpressions,
functionsWithDistinct,
functionsWithoutDistinct,
- aggregateFunctionToAttribute,
- rewrittenResultExpressions,
+ resultExpressions,
planLater(child))
}
@@ -366,6 +335,10 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
case r: RunnableCommand => ExecutedCommand(r) :: Nil
+ case MemoryPlan(sink, output) =>
+ val encoder = RowEncoder(sink.schema)
+ LocalTableScan(output, sink.allData.map(r => encoder.toRow(r).copy())) :: Nil
+
case logical.Distinct(child) =>
throw new IllegalStateException(
"logical distinct operator should have been replaced by aggregate in the optimizer")
@@ -373,8 +346,14 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
throw new IllegalStateException(
"logical intersect operator should have been replaced by semi-join in the optimizer")
+ case logical.DeserializeToObject(deserializer, child) =>
+ execution.DeserializeToObject(deserializer, planLater(child)) :: Nil
+ case logical.SerializeFromObject(serializer, child) =>
+ execution.SerializeFromObject(serializer, planLater(child)) :: Nil
case logical.MapPartitions(f, in, out, child) =>
execution.MapPartitions(f, in, out, planLater(child)) :: Nil
+ case logical.MapElements(f, in, out, child) =>
+ execution.MapElements(f, in, out, planLater(child)) :: Nil
case logical.AppendColumns(f, in, out, child) =>
execution.AppendColumns(f, in, out, planLater(child)) :: Nil
case logical.MapGroups(f, key, in, out, grouping, data, child) =>
@@ -426,8 +405,6 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
case logical.RepartitionByExpression(expressions, child, nPartitions) =>
exchange.ShuffleExchange(HashPartitioning(
expressions, nPartitions.getOrElse(numPartitions)), planLater(child)) :: Nil
- case e @ python.EvaluatePython(udf, child, _) =>
- python.BatchPythonEvaluation(udf, e.output, planLater(child)) :: Nil
case LogicalRDD(output, rdd) => PhysicalRDD(output, rdd, "ExistingRDD") :: Nil
case BroadcastHint(child) => planLater(child) :: Nil
case _ => Nil
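The reworked Aggregation strategy above extracts aggregate expressions via PhysicalAggregation and then branches on whether any of them are distinct; a simplified sketch of that dispatch (the case class and names are illustrative, not the Spark API, and partial-aggregation handling is elided):

// Partition aggregates on isDistinct and pick the planning path, mirroring the
// branch structure in the hunk above.
case class Agg(name: String, isDistinct: Boolean)

def chooseAggregatePlan(aggregateExpressions: Seq[Agg]): String = {
  val (functionsWithDistinct, functionsWithoutDistinct) =
    aggregateExpressions.partition(_.isDistinct)
  if (functionsWithDistinct.isEmpty) {
    s"planAggregateWithoutDistinct over ${functionsWithoutDistinct.size} functions"
  } else {
    "planAggregateWithOneDistinct" // a single distinct column set is assumed, as in the sanity check above
  }
}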
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
index 1b13c8fd22..447dbe7018 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
@@ -17,7 +17,7 @@
package org.apache.spark.sql.execution
-import org.apache.spark.broadcast
+import org.apache.spark.{broadcast, TaskContext}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
@@ -29,10 +29,11 @@ import org.apache.spark.sql.execution.aggregate.TungstenAggregate
import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, SortMergeJoin}
import org.apache.spark.sql.execution.metric.{LongSQLMetricValue, SQLMetrics}
import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types._
/**
- * An interface for those physical operators that support codegen.
- */
+ * An interface for those physical operators that support codegen.
+ */
trait CodegenSupport extends SparkPlan {
/** Prefix used in the current operator's variable names. */
@@ -46,10 +47,10 @@ trait CodegenSupport extends SparkPlan {
}
/**
- * Creates a metric using the specified name.
- *
- * @return name of the variable representing the metric
- */
+ * Creates a metric using the specified name.
+ *
+ * @return name of the variable representing the metric
+ */
def metricTerm(ctx: CodegenContext, name: String): String = {
val metric = ctx.addReferenceObj(name, longMetric(name))
val value = ctx.freshName("metricValue")
@@ -59,25 +60,25 @@ trait CodegenSupport extends SparkPlan {
}
/**
- * Whether this SparkPlan support whole stage codegen or not.
- */
+ * Whether this SparkPlan supports whole stage codegen or not.
+ */
def supportCodegen: Boolean = true
/**
- * Which SparkPlan is calling produce() of this one. It's itself for the first SparkPlan.
- */
+ * Which SparkPlan is calling produce() of this one. It's itself for the first SparkPlan.
+ */
protected var parent: CodegenSupport = null
/**
- * Returns all the RDDs of InternalRow which generates the input rows.
- *
- * Note: right now we support up to two RDDs.
- */
+ * Returns all the RDDs of InternalRow which generate the input rows.
+ *
+ * Note: right now we support up to two RDDs.
+ */
def upstreams(): Seq[RDD[InternalRow]]
/**
- * Returns Java source code to process the rows from upstream.
- */
+ * Returns Java source code to process the rows from upstream.
+ */
final def produce(ctx: CodegenContext, parent: CodegenSupport): String = {
this.parent = parent
ctx.freshNamePrefix = variablePrefix
@@ -89,28 +90,28 @@ trait CodegenSupport extends SparkPlan {
}
/**
- * Generate the Java source code to process, should be overridden by subclass to support codegen.
- *
- * doProduce() usually generate the framework, for example, aggregation could generate this:
- *
- * if (!initialized) {
- * # create a hash map, then build the aggregation hash map
- * # call child.produce()
- * initialized = true;
- * }
- * while (hashmap.hasNext()) {
- * row = hashmap.next();
- * # build the aggregation results
- * # create variables for results
- * # call consume(), which will call parent.doConsume()
+ * Generate the Java source code to process rows; this should be overridden by subclasses to support codegen.
+ *
+ * doProduce() usually generates the framework; for example, aggregation could generate this:
+ *
+ * if (!initialized) {
+ * # create a hash map, then build the aggregation hash map
+ * # call child.produce()
+ * initialized = true;
+ * }
+ * while (hashmap.hasNext()) {
+ * row = hashmap.next();
+ * # build the aggregation results
+ * # create variables for results
+ * # call consume(), which will call parent.doConsume()
* if (shouldStop()) return;
- * }
- */
+ * }
+ */
protected def doProduce(ctx: CodegenContext): String
/**
- * Consume the generated columns or row from current SparkPlan, call it's parent's doConsume().
- */
+ * Consume the generated columns or row from the current SparkPlan, and call its parent's doConsume().
+ */
final def consume(ctx: CodegenContext, outputVars: Seq[ExprCode], row: String = null): String = {
val inputVars =
if (row != null) {
@@ -152,15 +153,15 @@ trait CodegenSupport extends SparkPlan {
s"""
|
|/*** CONSUME: ${toCommentSafeString(parent.simpleString)} */
- |${evaluated}
+ |$evaluated
|${parent.doConsume(ctx, inputVars, rowVar)}
""".stripMargin
}
/**
- * Returns source code to evaluate all the variables, and clear the code of them, to prevent
- * them to be evaluated twice.
- */
+ * Returns source code to evaluate all the variables, and clears their code, to prevent
+ * them from being evaluated twice.
+ */
protected def evaluateVariables(variables: Seq[ExprCode]): String = {
val evaluate = variables.filter(_.code != "").map(_.code.trim).mkString("\n")
variables.foreach(_.code = "")
@@ -168,21 +169,21 @@ trait CodegenSupport extends SparkPlan {
}
/**
- * Returns source code to evaluate the variables for required attributes, and clear the code
- * of evaluated variables, to prevent them to be evaluated twice..
- */
+ * Returns source code to evaluate the variables for required attributes, and clear the code
+ * of evaluated variables, to prevent them from being evaluated twice.
+ */
protected def evaluateRequiredVariables(
attributes: Seq[Attribute],
variables: Seq[ExprCode],
required: AttributeSet): String = {
- var evaluateVars = ""
+ val evaluateVars = new StringBuilder
variables.zipWithIndex.foreach { case (ev, i) =>
if (ev.code != "" && required.contains(attributes(i))) {
- evaluateVars += ev.code.trim + "\n"
+ evaluateVars.append(ev.code.trim + "\n")
ev.code = ""
}
}
- evaluateVars
+ evaluateVars.toString()
}
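
A standalone sketch (with a hypothetical `Code` stand-in for ExprCode, not Spark code) of the clear-after-emit pattern that evaluateVariables/evaluateRequiredVariables implement above: emit each required variable's code once, then blank it so a later consume() cannot emit it a second time.

object EvaluateOnceSketch {
  final case class Code(var code: String, value: String)

  def evaluateRequired(attrs: Seq[String], vars: Seq[Code], required: Set[String]): String = {
    val out = new StringBuilder
    vars.zipWithIndex.foreach { case (v, i) =>
      if (v.code.nonEmpty && required.contains(attrs(i))) {
        out.append(v.code.trim).append("\n")
        v.code = ""                       // evaluated exactly once from now on
      }
    }
    out.toString()
  }

  def main(args: Array[String]): Unit = {
    val vars = Seq(Code("int a = row.getInt(0);", "a"), Code("int b = row.getInt(1);", "b"))
    println(evaluateRequired(Seq("a", "b"), vars, Set("a")))  // emits only a's code
    println(evaluateRequired(Seq("a", "b"), vars, Set("a")))  // now empty: already evaluated
  }
}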
/**
@@ -194,18 +195,18 @@ trait CodegenSupport extends SparkPlan {
def usedInputs: AttributeSet = references
/**
- * Generate the Java source code to process the rows from child SparkPlan.
- *
- * This should be override by subclass to support codegen.
- *
- * For example, Filter will generate the code like this:
- *
- * # code to evaluate the predicate expression, result is isNull1 and value2
- * if (isNull1 || !value2) continue;
- * # call consume(), which will call parent.doConsume()
- *
- * Note: A plan can either consume the rows as UnsafeRow (row), or a list of variables (input).
- */
+ * Generates the Java source code to process the rows from the child SparkPlan.
+ *
+ * This should be overridden by subclasses to support codegen.
+ *
+ * For example, Filter will generate code like this:
+ *
+ * # code to evaluate the predicate expression, result is isNull1 and value2
+ * if (isNull1 || !value2) continue;
+ * # call consume(), which will call parent.doConsume()
+ *
+ * Note: A plan can either consume the rows as UnsafeRow (row), or a list of variables (input).
+ */
def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
throw new UnsupportedOperationException
}
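
The produce()/consume() protocol documented above is easier to see in isolation. Below is a minimal, self-contained Scala sketch (hypothetical names, not Spark code) of the same pattern: the leaf operator generates the loop in doProduce(), and each parent contributes its per-row logic through doConsume(), so a chain of operators fuses into a single generated loop.

object ProduceConsumeSketch {
  trait Codegen {
    protected var parent: Codegen = null
    // Called by the parent; remembers who to push rows to, then generates the "loop".
    final def produce(parent: Codegen): String = { this.parent = parent; doProduce() }
    protected def doProduce(): String
    // Pushes a produced row variable up to the parent.
    final def consume(row: String): String = parent.doConsume(row)
    def doConsume(row: String): String = throw new UnsupportedOperationException
  }

  // Leaf: generates the scan loop and pushes each row to its parent.
  class Scan extends Codegen {
    protected def doProduce(): String =
      s"""while (input.hasNext()) {
         |  Row r = input.next();
         |  ${consume("r")}
         |}""".stripMargin
  }

  // Filter: asks the child to produce, and contributes a predicate check per row.
  class FilterOp(child: Codegen, predicate: String) extends Codegen {
    protected def doProduce(): String = child.produce(this)
    override def doConsume(row: String): String =
      s"""if (!($predicate)) continue;
         |${consume(row)}""".stripMargin
  }

  // Root: just prints the row; driving produce() on the tree yields one fused loop.
  class Print extends Codegen {
    protected def doProduce(): String = throw new UnsupportedOperationException
    override def doConsume(row: String): String = s"System.out.println($row);"
  }

  def main(args: Array[String]): Unit = {
    val plan = new FilterOp(new Scan, "r.getInt(0) > 0")
    println(plan.produce(new Print))
  }
}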
@@ -213,11 +214,11 @@ trait CodegenSupport extends SparkPlan {
/**
- * InputAdapter is used to hide a SparkPlan from a subtree that support codegen.
- *
- * This is the leaf node of a tree with WholeStageCodegen, is used to generate code that consumes
- * an RDD iterator of InternalRow.
- */
+ * InputAdapter is used to hide a SparkPlan from a subtree that supports codegen.
+ *
+ * This is the leaf node of a tree with WholeStageCodegen, and is used to generate code that consumes
+ * an RDD iterator of InternalRow.
+ */
case class InputAdapter(child: SparkPlan) extends UnaryNode with CodegenSupport {
override def output: Seq[Attribute] = child.output
@@ -260,33 +261,33 @@ object WholeStageCodegen {
}
/**
- * WholeStageCodegen compile a subtree of plans that support codegen together into single Java
- * function.
- *
- * Here is the call graph of to generate Java source (plan A support codegen, but plan B does not):
- *
- * WholeStageCodegen Plan A FakeInput Plan B
- * =========================================================================
- *
- * -> execute()
- * |
- * doExecute() ---------> upstreams() -------> upstreams() ------> execute()
- * |
- * +-----------------> produce()
- * |
- * doProduce() -------> produce()
- * |
- * doProduce()
- * |
- * doConsume() <--------- consume()
- * |
- * doConsume() <-------- consume()
- *
- * SparkPlan A should override doProduce() and doConsume().
- *
- * doCodeGen() will create a CodeGenContext, which will hold a list of variables for input,
- * used to generated code for BoundReference.
- */
+ * WholeStageCodegen compiles a subtree of plans that support codegen together into a single Java
+ * function.
+ *
+ * Here is the call graph for generating the Java source (plan A supports codegen, but plan B does not):
+ *
+ * WholeStageCodegen Plan A FakeInput Plan B
+ * =========================================================================
+ *
+ * -> execute()
+ * |
+ * doExecute() ---------> upstreams() -------> upstreams() ------> execute()
+ * |
+ * +-----------------> produce()
+ * |
+ * doProduce() -------> produce()
+ * |
+ * doProduce()
+ * |
+ * doConsume() <--------- consume()
+ * |
+ * doConsume() <-------- consume()
+ *
+ * SparkPlan A should override doProduce() and doConsume().
+ *
+ * doCodeGen() will create a CodegenContext, which will hold a list of variables for input,
+ * used to generate code for BoundReference.
+ */
case class WholeStageCodegen(child: SparkPlan) extends UnaryNode with CodegenSupport {
override def output: Seq[Attribute] = child.output
@@ -297,18 +298,22 @@ case class WholeStageCodegen(child: SparkPlan) extends UnaryNode with CodegenSup
"pipelineTime" -> SQLMetrics.createTimingMetric(sparkContext,
WholeStageCodegen.PIPELINE_DURATION_METRIC))
- override def doExecute(): RDD[InternalRow] = {
+ /**
+ * Generates code for this subtree.
+ *
+ * @return the tuple of the codegen context and the actual generated source.
+ */
+ def doCodeGen(): (CodegenContext, String) = {
val ctx = new CodegenContext
val code = child.asInstanceOf[CodegenSupport].produce(ctx, this)
- val references = ctx.references.toArray
val source = s"""
public Object generate(Object[] references) {
return new GeneratedIterator(references);
}
/** Codegened pipeline for:
- * ${toCommentSafeString(child.treeString.trim)}
- */
+ * ${toCommentSafeString(child.treeString.trim)}
+ */
final class GeneratedIterator extends org.apache.spark.sql.execution.BufferedRowIterator {
private Object[] references;
@@ -318,7 +323,8 @@ case class WholeStageCodegen(child: SparkPlan) extends UnaryNode with CodegenSup
this.references = references;
}
- public void init(scala.collection.Iterator inputs[]) {
+ public void init(int index, scala.collection.Iterator inputs[]) {
+ partitionIndex = index;
${ctx.initMutableStates()}
}
@@ -332,18 +338,24 @@ case class WholeStageCodegen(child: SparkPlan) extends UnaryNode with CodegenSup
// try to compile, helpful for debug
val cleanedSource = CodeFormatter.stripExtraNewLines(source)
- logDebug(s"${CodeFormatter.format(cleanedSource)}")
+ logDebug(s"\n${CodeFormatter.format(cleanedSource)}")
CodeGenerator.compile(cleanedSource)
+ (ctx, cleanedSource)
+ }
+
+ override def doExecute(): RDD[InternalRow] = {
+ val (ctx, cleanedSource) = doCodeGen()
+ val references = ctx.references.toArray
val durationMs = longMetric("pipelineTime")
val rdds = child.asInstanceOf[CodegenSupport].upstreams()
assert(rdds.size <= 2, "Up to two upstream RDDs can be supported")
if (rdds.length == 1) {
- rdds.head.mapPartitions { iter =>
+ rdds.head.mapPartitionsWithIndex { (index, iter) =>
val clazz = CodeGenerator.compile(cleanedSource)
val buffer = clazz.generate(references).asInstanceOf[BufferedRowIterator]
- buffer.init(Array(iter))
+ buffer.init(index, Array(iter))
new Iterator[InternalRow] {
override def hasNext: Boolean = {
val v = buffer.hasNext
@@ -356,9 +368,10 @@ case class WholeStageCodegen(child: SparkPlan) extends UnaryNode with CodegenSup
} else {
// Right now, we support up to two upstreams.
rdds.head.zipPartitions(rdds(1)) { (leftIter, rightIter) =>
+ val partitionIndex = TaskContext.getPartitionId()
val clazz = CodeGenerator.compile(cleanedSource)
val buffer = clazz.generate(references).asInstanceOf[BufferedRowIterator]
- buffer.init(Array(leftIter, rightIter))
+ buffer.init(partitionIndex, Array(leftIter, rightIter))
new Iterator[InternalRow] {
override def hasNext: Boolean = {
val v = buffer.hasNext
@@ -409,8 +422,8 @@ case class WholeStageCodegen(child: SparkPlan) extends UnaryNode with CodegenSup
/**
- * Find the chained plans that support codegen, collapse them together as WholeStageCodegen.
- */
+ * Find the chained plans that support codegen and collapse them together as a WholeStageCodegen.
+ */
case class CollapseCodegenStages(conf: SQLConf) extends Rule[SparkPlan] {
private def supportCodegen(e: Expression): Boolean = e match {
@@ -421,12 +434,23 @@ case class CollapseCodegenStages(conf: SQLConf) extends Rule[SparkPlan] {
case _ => true
}
+ private def numOfNestedFields(dataType: DataType): Int = dataType match {
+ case dt: StructType => dt.fields.map(f => numOfNestedFields(f.dataType)).sum
+ case m: MapType => numOfNestedFields(m.keyType) + numOfNestedFields(m.valueType)
+ case a: ArrayType => numOfNestedFields(a.elementType)
+ case u: UserDefinedType[_] => numOfNestedFields(u.sqlType)
+ case _ => 1
+ }
+
private def supportCodegen(plan: SparkPlan): Boolean = plan match {
case plan: CodegenSupport if plan.supportCodegen =>
val willFallback = plan.expressions.exists(_.find(e => !supportCodegen(e)).isDefined)
// the generated code will be huge if there are too many columns
- val haveManyColumns = plan.output.length > 200
- !willFallback && !haveManyColumns
+ val hasTooManyOutputFields =
+ numOfNestedFields(plan.schema) > conf.wholeStageMaxNumFields
+ val hasTooManyInputFields =
+ plan.children.map(p => numOfNestedFields(p.schema)).exists(_ > conf.wholeStageMaxNumFields)
+ !willFallback && !hasTooManyOutputFields && !hasTooManyInputFields
case _ => false
}
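
A hedged sketch of the nested-field heuristic above, applied to a concrete schema. It assumes spark-sql is on the classpath; the limit of 200 is only illustrative, since the real bound comes from conf.wholeStageMaxNumFields, and the UserDefinedType case is omitted for brevity.

import org.apache.spark.sql.types._

object FieldCountSketch {
  def numOfNestedFields(dataType: DataType): Int = dataType match {
    case dt: StructType => dt.fields.map(f => numOfNestedFields(f.dataType)).sum
    case m: MapType => numOfNestedFields(m.keyType) + numOfNestedFields(m.valueType)
    case a: ArrayType => numOfNestedFields(a.elementType)
    case _ => 1
  }

  def main(args: Array[String]): Unit = {
    val schema = new StructType()
      .add("id", LongType)
      .add("tags", ArrayType(StringType))
      .add("props", MapType(StringType, new StructType().add("x", IntegerType).add("y", IntegerType)))
    val n = numOfNestedFields(schema)          // 1 + 1 + (1 + 2) = 5
    val illustrativeLimit = 200
    println(s"nested fields = $n, codegen ok = ${n <= illustrativeLimit}")
  }
}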
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala
index 270c09aff3..8e9214fa25 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala
@@ -177,7 +177,7 @@ case class Window(
case e @ WindowExpression(function, spec) =>
val frame = spec.frameSpecification.asInstanceOf[SpecifiedWindowFrame]
function match {
- case AggregateExpression(f, _, _) => collect("AGGREGATE", frame, e, f)
+ case AggregateExpression(f, _, _, _) => collect("AGGREGATE", frame, e, f)
case f: AggregateWindowFunction => collect("AGGREGATE", frame, e, f)
case f: OffsetWindowFunction => collect("OFFSET", frame, e, f)
case f => sys.error(s"Unsupported window function: $f")
@@ -444,8 +444,8 @@ private[execution] final case class RangeBoundOrdering(
}
/**
- * The interface of row buffer for a partition
- */
+ * The interface of row buffer for a partition
+ */
private[execution] abstract class RowBuffer {
/** Number of rows. */
@@ -462,8 +462,8 @@ private[execution] abstract class RowBuffer {
}
/**
- * A row buffer based on ArrayBuffer (the number of rows is limited)
- */
+ * A row buffer based on ArrayBuffer (the number of rows is limited)
+ */
private[execution] class ArrayRowBuffer(buffer: ArrayBuffer[UnsafeRow]) extends RowBuffer {
private[this] var cursor: Int = -1
@@ -493,8 +493,8 @@ private[execution] class ArrayRowBuffer(buffer: ArrayBuffer[UnsafeRow]) extends
}
/**
- * An external buffer of rows based on UnsafeExternalSorter
- */
+ * An external buffer of rows based on UnsafeExternalSorter
+ */
private[execution] class ExternalRowBuffer(sorter: UnsafeExternalSorter, numFields: Int)
extends RowBuffer {
@@ -654,12 +654,16 @@ private[execution] final class SlidingWindowFunctionFrame(
/** The rows within current sliding window. */
private[this] val buffer = new util.ArrayDeque[InternalRow]()
- /** Index of the first input row with a value greater than the upper bound of the current
- * output row. */
+ /**
+ * Index of the first input row with a value greater than the upper bound of the current
+ * output row.
+ */
private[this] var inputHighIndex = 0
- /** Index of the first input row with a value equal to or greater than the lower bound of the
- * current output row. */
+ /**
+ * Index of the first input row with a value equal to or greater than the lower bound of the
+ * current output row.
+ */
private[this] var inputLowIndex = 0
/** Prepare the frame for calculating a new partition. Reset all variables. */
@@ -763,8 +767,10 @@ private[execution] final class UnboundedPrecedingWindowFunctionFrame(
/** The next row from `input`. */
private[this] var nextRow: InternalRow = null
- /** Index of the first input row with a value greater than the upper bound of the current
- * output row. */
+ /**
+ * Index of the first input row with a value greater than the upper bound of the current
+ * output row.
+ */
private[this] var inputIndex = 0
/** Prepare the frame for calculating a new partition. */
@@ -805,7 +811,7 @@ private[execution] final class UnboundedPrecedingWindowFunctionFrame(
*
* This is a very expensive operator to use, O(n * (n - 1) /2), because we need to maintain a
* buffer and must do full recalculation after each row. Reverse iteration would be possible, if
- * the communitativity of the used window functions can be guaranteed.
+ * the commutativity of the used window functions can be guaranteed.
*
* @param target to write results to.
* @param processor to calculate the row values with.
@@ -819,8 +825,10 @@ private[execution] final class UnboundedFollowingWindowFunctionFrame(
/** Rows of the partition currently being processed. */
private[this] var input: RowBuffer = null
- /** Index of the first input row with a value equal to or greater than the lower bound of the
- * current output row. */
+ /**
+ * Index of the first input row with a value equal to or greater than the lower bound of the
+ * current output row.
+ */
private[this] var inputIndex = 0
/** Prepare the frame for calculating a new partition. */
@@ -874,7 +882,8 @@ private[execution] final class UnboundedFollowingWindowFunctionFrame(
* processor class.
*/
private[execution] object AggregateProcessor {
- def apply(functions: Array[Expression],
+ def apply(
+ functions: Array[Expression],
ordinal: Int,
inputAttributes: Seq[Attribute],
newMutableProjection: (Seq[Expression], Seq[Attribute]) => () => MutableProjection):
@@ -885,11 +894,20 @@ private[execution] object AggregateProcessor {
val evaluateExpressions = mutable.Buffer.fill[Expression](ordinal)(NoOp)
val imperatives = mutable.Buffer.empty[ImperativeAggregate]
+ // SPARK-14244: `SizeBasedWindowFunction`s are first created on the driver side and then
+ // serialized to the executor side. These functions all reference a global singleton window
+ // partition size attribute reference, i.e., `SizeBasedWindowFunction.n`. Here we must collect
+ // the singleton instance created on the driver side instead of using the executor-side
+ // `SizeBasedWindowFunction.n`, to avoid binding failures caused by mismatching expression IDs.
+ val partitionSize: Option[AttributeReference] = {
+ val aggs = functions.flatMap(_.collectFirst { case f: SizeBasedWindowFunction => f })
+ aggs.headOption.map(_.n)
+ }
+
// Check if there are any SizeBasedWindowFunctions. If there are, we add the partition size to
// the aggregation buffer. Note that the ordinal of the partition size value will always be 0.
- val trackPartitionSize = functions.exists(_.isInstanceOf[SizeBasedWindowFunction])
- if (trackPartitionSize) {
- aggBufferAttributes += SizeBasedWindowFunction.n
+ partitionSize.foreach { n =>
+ aggBufferAttributes += n
initialValues += NoOp
updateExpressions += NoOp
}
@@ -920,7 +938,7 @@ private[execution] object AggregateProcessor {
// Create the projections.
val initialProjection = newMutableProjection(
initialValues,
- Seq(SizeBasedWindowFunction.n))()
+ partitionSize.toSeq)()
val updateProjection = newMutableProjection(
updateExpressions,
aggBufferAttributes ++ inputAttributes)()
@@ -935,7 +953,7 @@ private[execution] object AggregateProcessor {
updateProjection,
evaluateProjection,
imperatives.toArray,
- trackPartitionSize)
+ partitionSize.isDefined)
}
}
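
A toy illustration (plain Scala, not Spark's expression classes) of the exprId mismatch the SPARK-14244 comment above describes: binding resolves attributes by id, so the aggregation buffer must carry the exact singleton instance collected from the driver-side functions, not a freshly created one with a different id.

object ExprIdSketch {
  final case class Attr(name: String, id: Long)

  // Binding: find the ordinal of `a` in the buffer by id, the way BindReferences keys on exprId.
  def bind(a: Attr, buffer: Seq[Attr]): Int = {
    val i = buffer.indexWhere(_.id == a.id)
    require(i >= 0, s"Couldn't find ${a.name}#${a.id} in ${buffer.mkString(",")}")
    i
  }

  def main(args: Array[String]): Unit = {
    val driverSideN = Attr("window__partition__size", 1)   // instance shipped with the functions
    val executorSideN = Attr("window__partition__size", 2) // a different singleton, different id
    val buffer = Seq(driverSideN)
    println(bind(driverSideN, buffer))   // ok: 0
    // bind(executorSideN, buffer)       // would fail: same name, different id
  }
}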
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala
index 15627a7004..042c731901 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala
@@ -47,17 +47,17 @@ abstract class AggregationIterator(
///////////////////////////////////////////////////////////////////////////
/**
- * The following combinations of AggregationMode are supported:
- * - Partial
- * - PartialMerge (for single distinct)
- * - Partial and PartialMerge (for single distinct)
- * - Final
- * - Complete (for SortBasedAggregate with functions that does not support Partial)
- * - Final and Complete (currently not used)
- *
- * TODO: AggregateMode should have only two modes: Update and Merge, AggregateExpression
- * could have a flag to tell it's final or not.
- */
+ * The following combinations of AggregationMode are supported:
+ * - Partial
+ * - PartialMerge (for single distinct)
+ * - Partial and PartialMerge (for single distinct)
+ * - Final
+ * - Complete (for SortBasedAggregate with functions that do not support Partial)
+ * - Final and Complete (currently not used)
+ *
+ * TODO: AggregateMode should have only two modes: Update and Merge; AggregateExpression
+ * could have a flag to tell whether it is final or not.
+ */
{
val modes = aggregateExpressions.map(_.mode).distinct.toSet
require(modes.size <= 2,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ColumnarAggMapCodeGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ColumnarAggMapCodeGenerator.scala
new file mode 100644
index 0000000000..e415dd8e6a
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ColumnarAggMapCodeGenerator.scala
@@ -0,0 +1,193 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.aggregate
+
+import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext
+import org.apache.spark.sql.types.StructType
+
+/**
+ * This is a helper object to generate an append-only single-key/single value aggregate hash
+ * map that can act as a 'cache' for extremely fast key-value lookups while evaluating aggregates
+ * (and fall back to the `BytesToBytesMap` if a given key isn't found). This is 'codegened' in
+ * TungstenAggregate to speed up aggregates w/ key.
+ *
+ * It is backed by a power-of-2-sized array for index lookups and a columnar batch that stores the
+ * key-value pairs. The index lookups in the array rely on linear probing (with a small number of
+ * maximum tries) and use an inexpensive hash function which makes it really efficient for a
+ * majority of lookups. However, using linear probing and an inexpensive hash function also makes it
+ * less robust than the `BytesToBytesMap` (especially for a large number of keys or even
+ * for certain distributions of keys), and requires us to fall back on the latter for correctness.
+ */
+class ColumnarAggMapCodeGenerator(
+ ctx: CodegenContext,
+ generatedClassName: String,
+ groupingKeySchema: StructType,
+ bufferSchema: StructType) {
+ val groupingKeys = groupingKeySchema.map(k => (k.dataType.typeName, ctx.freshName("key")))
+ val bufferValues = bufferSchema.map(k => (k.dataType.typeName, ctx.freshName("value")))
+ val groupingKeySignature = groupingKeys.map(_.productIterator.toList.mkString(" ")).mkString(", ")
+
+ def generate(): String = {
+ s"""
+ |public class $generatedClassName {
+ |${initializeAggregateHashMap()}
+ |
+ |${generateFindOrInsert()}
+ |
+ |${generateEquals()}
+ |
+ |${generateHashFunction()}
+ |}
+ """.stripMargin
+ }
+
+ private def initializeAggregateHashMap(): String = {
+ val generatedSchema: String =
+ s"""
+ |new org.apache.spark.sql.types.StructType()
+ |${(groupingKeySchema ++ bufferSchema).map(key =>
+ s""".add("${key.name}", org.apache.spark.sql.types.DataTypes.${key.dataType})""")
+ .mkString("\n")};
+ """.stripMargin
+
+ s"""
+ | private org.apache.spark.sql.execution.vectorized.ColumnarBatch batch;
+ | private int[] buckets;
+ | private int numBuckets;
+ | private int maxSteps;
+ | private int numRows = 0;
+ | private org.apache.spark.sql.types.StructType schema = $generatedSchema
+ |
+ | public $generatedClassName(int capacity, double loadFactor, int maxSteps) {
+ | assert (capacity > 0 && ((capacity & (capacity - 1)) == 0));
+ | this.maxSteps = maxSteps;
+ | numBuckets = (int) (capacity / loadFactor);
+ | batch = org.apache.spark.sql.execution.vectorized.ColumnarBatch.allocate(schema,
+ | org.apache.spark.memory.MemoryMode.ON_HEAP, capacity);
+ | buckets = new int[numBuckets];
+ | java.util.Arrays.fill(buckets, -1);
+ | }
+ |
+ | public $generatedClassName() {
+ | new $generatedClassName(1 << 16, 0.25, 5);
+ | }
+ """.stripMargin
+ }
+
+ /**
+ * Generates a method that computes a hash (currently by xor-ing all individual group-by keys). For
+ * instance, if we have 2 long group-by keys, the generated function would be of the form:
+ *
+ * {{{
+ * private long hash(long agg_key, long agg_key1) {
+ * return agg_key ^ agg_key1;
+ * }
+ * }}}
+ */
+ private def generateHashFunction(): String = {
+ s"""
+ |// TODO: Improve this hash function
+ |private long hash($groupingKeySignature) {
+ | return ${groupingKeys.map(_._2).mkString(" ^ ")};
+ |}
+ """.stripMargin
+ }
+
+ /**
+ * Generates a method that returns true if the group-by keys exist at a given index in the
+ * associated [[org.apache.spark.sql.execution.vectorized.ColumnarBatch]]. For instance, if we
+ * have 2 long group-by keys, the generated function would be of the form:
+ *
+ * {{{
+ * private boolean equals(int idx, long agg_key, long agg_key1) {
+ * return batch.column(0).getLong(buckets[idx]) == agg_key &&
+ * batch.column(1).getLong(buckets[idx]) == agg_key1;
+ * }
+ * }}}
+ */
+ private def generateEquals(): String = {
+ s"""
+ |private boolean equals(int idx, $groupingKeySignature) {
+ | return ${groupingKeys.zipWithIndex.map(k =>
+ s"batch.column(${k._2}).getLong(buckets[idx]) == ${k._1._2}").mkString(" && ")};
+ |}
+ """.stripMargin
+ }
+
+ /**
+ * Generates a method that returns a mutable
+ * [[org.apache.spark.sql.execution.vectorized.ColumnarBatch.Row]] which keeps track of the
+ * aggregate value(s) for a given set of keys. If the corresponding row doesn't exist, the
+ * generated method adds the corresponding row in the associated
+ * [[org.apache.spark.sql.execution.vectorized.ColumnarBatch]]. For instance, if we
+ * have 2 long group-by keys, the generated function would be of the form:
+ *
+ * {{{
+ * public org.apache.spark.sql.execution.vectorized.ColumnarBatch.Row findOrInsert(
+ * long agg_key, long agg_key1) {
+ * long h = hash(agg_key, agg_key1);
+ * int step = 0;
+ * int idx = (int) h & (numBuckets - 1);
+ * while (step < maxSteps) {
+ * // Return bucket index if it's either an empty slot or already contains the key
+ * if (buckets[idx] == -1) {
+ * batch.column(0).putLong(numRows, agg_key);
+ * batch.column(1).putLong(numRows, agg_key1);
+ * batch.column(2).putLong(numRows, 0);
+ * buckets[idx] = numRows++;
+ * return batch.getRow(buckets[idx]);
+ * } else if (equals(idx, agg_key, agg_key1)) {
+ * return batch.getRow(buckets[idx]);
+ * }
+ * idx = (idx + 1) & (numBuckets - 1);
+ * step++;
+ * }
+ * // Didn't find it
+ * return null;
+ * }
+ * }}}
+ */
+ private def generateFindOrInsert(): String = {
+ s"""
+ |public org.apache.spark.sql.execution.vectorized.ColumnarBatch.Row findOrInsert(${
+ groupingKeySignature}) {
+ | long h = hash(${groupingKeys.map(_._2).mkString(", ")});
+ | int step = 0;
+ | int idx = (int) h & (numBuckets - 1);
+ | while (step < maxSteps) {
+ | // Return bucket index if it's either an empty slot or already contains the key
+ | if (buckets[idx] == -1) {
+ | ${groupingKeys.zipWithIndex.map(k =>
+ s"batch.column(${k._2}).putLong(numRows, ${k._1._2});").mkString("\n")}
+ | ${bufferValues.zipWithIndex.map(k =>
+ s"batch.column(${groupingKeys.length + k._2}).putLong(numRows, 0);")
+ .mkString("\n")}
+ | buckets[idx] = numRows++;
+ | return batch.getRow(buckets[idx]);
+ | } else if (equals(idx, ${groupingKeys.map(_._2).mkString(", ")})) {
+ | return batch.getRow(buckets[idx]);
+ | }
+ | idx = (idx + 1) & (numBuckets - 1);
+ | step++;
+ | }
+ | // Didn't find it
+ | return null;
+ |}
+ """.stripMargin
+ }
+}
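
Not the generated Java itself, but a plain-Scala sketch of the same technique the generated class above uses: a power-of-two bucket array, bounded linear probing, a trivial hash, and a "give up" signal when probing fails (at which point the caller would fall back to the BytesToBytesMap). Class and method names here are made up for illustration.

object LinearProbingAggMapSketch {
  final class LongAggMap(capacity: Int, maxSteps: Int) {
    require(capacity > 0 && (capacity & (capacity - 1)) == 0, "capacity must be a power of 2")
    private val keys = new Array[Long](capacity)
    private val sums = new Array[Long](capacity)
    private val buckets = Array.fill(4 * capacity)(-1)   // load factor 0.25, as in the generator
    private var numRows = 0

    private def hash(key: Long): Long = key              // the generator also starts with a trivial hash

    // Returns the row index holding `key`, inserting it with a zeroed buffer if absent,
    // or -1 if the probe gave up after maxSteps (the caller falls back to another map).
    def findOrInsert(key: Long): Int = {
      var idx = (hash(key) & (buckets.length - 1)).toInt
      var step = 0
      while (step < maxSteps) {
        if (buckets(idx) == -1) {
          keys(numRows) = key; sums(numRows) = 0L
          buckets(idx) = numRows; numRows += 1
          return buckets(idx)
        } else if (keys(buckets(idx)) == key) {
          return buckets(idx)
        }
        idx = (idx + 1) & (buckets.length - 1); step += 1
      }
      -1
    }

    def add(key: Long, value: Long): Unit = {
      val row = findOrInsert(key)
      if (row >= 0) sums(row) += value
    }

    def sumOf(key: Long): Long = sums(findOrInsert(key))
  }

  def main(args: Array[String]): Unit = {
    val map = new LongAggMap(capacity = 1 << 4, maxSteps = 5)
    Seq(1L -> 10L, 2L -> 5L, 1L -> 7L).foreach { case (k, v) => map.add(k, v) }
    println(map.sumOf(1L))   // 17
  }
}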
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala
index 8f974980bb..de1491d357 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala
@@ -46,9 +46,9 @@ class SortBasedAggregationIterator(
newMutableProjection) {
/**
- * Creates a new aggregation buffer and initializes buffer values
- * for all aggregate functions.
- */
+ * Creates a new aggregation buffer and initializes buffer values
+ * for all aggregate functions.
+ */
private def newBuffer: MutableRow = {
val bufferSchema = aggregateFunctions.flatMap(_.aggBufferAttributes)
val bufferRowSize: Int = bufferSchema.length
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
index 7c215d1b96..253592028c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
@@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.metric.SQLMetrics
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.types.{LongType, StructType}
import org.apache.spark.unsafe.KVIterator
case class TungstenAggregate(
@@ -64,8 +64,8 @@ case class TungstenAggregate(
override def requiredChildDistribution: List[Distribution] = {
requiredChildDistributionExpressions match {
- case Some(exprs) if exprs.length == 0 => AllTuples :: Nil
- case Some(exprs) if exprs.length > 0 => ClusteredDistribution(exprs) :: Nil
+ case Some(exprs) if exprs.isEmpty => AllTuples :: Nil
+ case Some(exprs) if exprs.nonEmpty => ClusteredDistribution(exprs) :: Nil
case None => UnspecifiedDistribution :: Nil
}
}
@@ -266,8 +266,8 @@ case class TungstenAggregate(
private var sorterTerm: String = _
/**
- * This is called by generated Java class, should be public.
- */
+ * This is called by the generated Java class, so it should be public.
+ */
def createHashMap(): UnsafeFixedWidthAggregationMap = {
// create initialized aggregate buffer
val initExpr = declFunctions.flatMap(f => f.initialValues)
@@ -286,15 +286,15 @@ case class TungstenAggregate(
}
/**
- * This is called by generated Java class, should be public.
- */
+ * This is called by the generated Java class, so it should be public.
+ */
def createUnsafeJoiner(): UnsafeRowJoiner = {
GenerateUnsafeRowJoiner.create(groupingKeySchema, bufferSchema)
}
/**
- * Called by generated Java class to finish the aggregate and return a KVIterator.
- */
+ * Called by the generated Java class to finish the aggregate and return a KVIterator.
+ */
def finishAggregate(
hashMap: UnsafeFixedWidthAggregationMap,
sorter: UnsafeKVExternalSorter): KVIterator[UnsafeRow, UnsafeRow] = {
@@ -372,8 +372,8 @@ case class TungstenAggregate(
}
/**
- * Generate the code for output.
- */
+ * Generate the code for output.
+ */
private def generateResultCode(
ctx: CodegenContext,
keyTerm: String,
@@ -437,11 +437,24 @@ case class TungstenAggregate(
val initAgg = ctx.freshName("initAgg")
ctx.addMutableState("boolean", initAgg, s"$initAgg = false;")
+ // create AggregateHashMap
+ val isAggregateHashMapEnabled: Boolean = false
+ val isAggregateHashMapSupported: Boolean =
+ (groupingKeySchema ++ bufferSchema).forall(_.dataType == LongType)
+ val aggregateHashMapTerm = ctx.freshName("aggregateHashMap")
+ val aggregateHashMapClassName = ctx.freshName("GeneratedAggregateHashMap")
+ val aggregateHashMapGenerator = new ColumnarAggMapCodeGenerator(ctx, aggregateHashMapClassName,
+ groupingKeySchema, bufferSchema)
+ if (isAggregateHashMapEnabled && isAggregateHashMapSupported) {
+ ctx.addMutableState(aggregateHashMapClassName, aggregateHashMapTerm,
+ s"$aggregateHashMapTerm = new $aggregateHashMapClassName();")
+ }
+
// create hashMap
val thisPlan = ctx.addReferenceObj("plan", this)
hashMapTerm = ctx.freshName("hashMap")
val hashMapClassName = classOf[UnsafeFixedWidthAggregationMap].getName
- ctx.addMutableState(hashMapClassName, hashMapTerm, s"$hashMapTerm = $thisPlan.createHashMap();")
+ ctx.addMutableState(hashMapClassName, hashMapTerm, "")
sorterTerm = ctx.freshName("sorter")
ctx.addMutableState(classOf[UnsafeKVExternalSorter].getName, sorterTerm, "")
@@ -452,7 +465,9 @@ case class TungstenAggregate(
val doAgg = ctx.freshName("doAggregateWithKeys")
ctx.addNewFunction(doAgg,
s"""
+ ${if (isAggregateHashMapSupported) aggregateHashMapGenerator.generate() else ""}
private void $doAgg() throws java.io.IOException {
+ $hashMapTerm = $thisPlan.createHashMap();
${child.asInstanceOf[CodegenSupport].produce(ctx, this)}
$iterTerm = $thisPlan.finishAggregate($hashMapTerm, $sorterTerm);
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala
index 213bca907b..ce504e20e6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala
@@ -242,9 +242,9 @@ class TungstenAggregationIterator(
// Basically the value of the KVIterator returned by externalSorter
// will be just aggregation buffer, so we rewrite the aggregateExpressions to reflect it.
val newExpressions = aggregateExpressions.map {
- case agg @ AggregateExpression(_, Partial, _) =>
+ case agg @ AggregateExpression(_, Partial, _, _) =>
agg.copy(mode = PartialMerge)
- case agg @ AggregateExpression(_, Complete, _) =>
+ case agg @ AggregateExpression(_, Complete, _, _) =>
agg.copy(mode = Final)
case other => other
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/typedaggregators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/typedaggregators.scala
new file mode 100644
index 0000000000..c39a78da6f
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/typedaggregators.scala
@@ -0,0 +1,101 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.aggregate
+
+import org.apache.spark.api.java.function.MapFunction
+import org.apache.spark.sql.{Encoder, TypedColumn}
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
+import org.apache.spark.sql.expressions.Aggregator
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+// This file defines internal implementations for aggregators.
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+
+class TypedSumDouble[IN](f: IN => Double) extends Aggregator[IN, Double, Double] {
+ override def zero: Double = 0.0
+ override def reduce(b: Double, a: IN): Double = b + f(a)
+ override def merge(b1: Double, b2: Double): Double = b1 + b2
+ override def finish(reduction: Double): Double = reduction
+
+ override def bufferEncoder: Encoder[Double] = ExpressionEncoder[Double]()
+ override def outputEncoder: Encoder[Double] = ExpressionEncoder[Double]()
+
+ // Java api support
+ def this(f: MapFunction[IN, java.lang.Double]) = this(x => f.call(x).asInstanceOf[Double])
+
+ def toColumnJava: TypedColumn[IN, java.lang.Double] = {
+ toColumn.asInstanceOf[TypedColumn[IN, java.lang.Double]]
+ }
+}
+
+
+class TypedSumLong[IN](f: IN => Long) extends Aggregator[IN, Long, Long] {
+ override def zero: Long = 0L
+ override def reduce(b: Long, a: IN): Long = b + f(a)
+ override def merge(b1: Long, b2: Long): Long = b1 + b2
+ override def finish(reduction: Long): Long = reduction
+
+ override def bufferEncoder: Encoder[Long] = ExpressionEncoder[Long]()
+ override def outputEncoder: Encoder[Long] = ExpressionEncoder[Long]()
+
+ // Java api support
+ def this(f: MapFunction[IN, java.lang.Long]) = this(x => f.call(x).asInstanceOf[Long])
+
+ def toColumnJava: TypedColumn[IN, java.lang.Long] = {
+ toColumn.asInstanceOf[TypedColumn[IN, java.lang.Long]]
+ }
+}
+
+
+class TypedCount[IN](f: IN => Any) extends Aggregator[IN, Long, Long] {
+ override def zero: Long = 0
+ override def reduce(b: Long, a: IN): Long = {
+ if (f(a) == null) b else b + 1
+ }
+ override def merge(b1: Long, b2: Long): Long = b1 + b2
+ override def finish(reduction: Long): Long = reduction
+
+ override def bufferEncoder: Encoder[Long] = ExpressionEncoder[Long]()
+ override def outputEncoder: Encoder[Long] = ExpressionEncoder[Long]()
+
+ // Java api support
+ def this(f: MapFunction[IN, Object]) = this(x => f.call(x))
+ def toColumnJava: TypedColumn[IN, java.lang.Long] = {
+ toColumn.asInstanceOf[TypedColumn[IN, java.lang.Long]]
+ }
+}
+
+
+class TypedAverage[IN](f: IN => Double) extends Aggregator[IN, (Double, Long), Double] {
+ override def zero: (Double, Long) = (0.0, 0L)
+ override def reduce(b: (Double, Long), a: IN): (Double, Long) = (f(a) + b._1, 1 + b._2)
+ override def finish(reduction: (Double, Long)): Double = reduction._1 / reduction._2
+ override def merge(b1: (Double, Long), b2: (Double, Long)): (Double, Long) = {
+ (b1._1 + b2._1, b1._2 + b2._2)
+ }
+
+ override def bufferEncoder: Encoder[(Double, Long)] = ExpressionEncoder[(Double, Long)]()
+ override def outputEncoder: Encoder[Double] = ExpressionEncoder[Double]()
+
+ // Java api support
+ def this(f: MapFunction[IN, java.lang.Double]) = this(x => f.call(x).asInstanceOf[Double])
+ def toColumnJava: TypedColumn[IN, java.lang.Double] = {
+ toColumn.asInstanceOf[TypedColumn[IN, java.lang.Double]]
+ }
+}
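
A hedged usage sketch for an Aggregator written in the style above. The `Sale` case class, the sample data, and the SparkSession entry point are assumptions (SparkSession may not exist in builds contemporary with this change); only the Aggregator/ExpressionEncoder/toColumn API mirrors the code in this file.

import org.apache.spark.sql.{Encoder, SparkSession}
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.expressions.Aggregator

case class Sale(item: String, amount: Double)   // hypothetical input type

class SumAmount extends Aggregator[Sale, Double, Double] {
  override def zero: Double = 0.0
  override def reduce(b: Double, a: Sale): Double = b + a.amount
  override def merge(b1: Double, b2: Double): Double = b1 + b2
  override def finish(reduction: Double): Double = reduction
  override def bufferEncoder: Encoder[Double] = ExpressionEncoder[Double]()
  override def outputEncoder: Encoder[Double] = ExpressionEncoder[Double]()
}

object TypedAggregatorUsage {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("agg-sketch").getOrCreate()
    import spark.implicits._
    val ds = Seq(Sale("a", 1.5), Sale("b", 2.0), Sale("a", 3.0)).toDS()
    // toColumn turns the Aggregator into a TypedColumn usable in select().
    ds.select(new SumAmount().toColumn).show()
    spark.stop()
  }
}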
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala
index 1e113ccd4e..4682949fa1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.aggregate
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.execution.streaming.{StateStoreRestore, StateStoreSave}
/**
* Utility functions used by the query planner to convert our plan to new aggregation code path.
@@ -29,15 +30,11 @@ object Utils {
def planAggregateWithoutPartial(
groupingExpressions: Seq[NamedExpression],
aggregateExpressions: Seq[AggregateExpression],
- aggregateFunctionToAttribute: Map[(AggregateFunction, Boolean), Attribute],
resultExpressions: Seq[NamedExpression],
child: SparkPlan): Seq[SparkPlan] = {
val completeAggregateExpressions = aggregateExpressions.map(_.copy(mode = Complete))
- val completeAggregateAttributes = completeAggregateExpressions.map {
- expr => aggregateFunctionToAttribute(expr.aggregateFunction, expr.isDistinct)
- }
-
+ val completeAggregateAttributes = completeAggregateExpressions.map(_.resultAttribute)
SortBasedAggregate(
requiredChildDistributionExpressions = Some(groupingExpressions),
groupingExpressions = groupingExpressions,
@@ -83,7 +80,6 @@ object Utils {
def planAggregateWithoutDistinct(
groupingExpressions: Seq[NamedExpression],
aggregateExpressions: Seq[AggregateExpression],
- aggregateFunctionToAttribute: Map[(AggregateFunction, Boolean), Attribute],
resultExpressions: Seq[NamedExpression],
child: SparkPlan): Seq[SparkPlan] = {
// Check if we can use TungstenAggregate.
@@ -111,9 +107,7 @@ object Utils {
val finalAggregateExpressions = aggregateExpressions.map(_.copy(mode = Final))
// The attributes of the final aggregation buffer, which is presented as input to the result
// projection:
- val finalAggregateAttributes = finalAggregateExpressions.map {
- expr => aggregateFunctionToAttribute(expr.aggregateFunction, expr.isDistinct)
- }
+ val finalAggregateAttributes = finalAggregateExpressions.map(_.resultAttribute)
val finalAggregate = createAggregate(
requiredChildDistributionExpressions = Some(groupingAttributes),
@@ -131,7 +125,6 @@ object Utils {
groupingExpressions: Seq[NamedExpression],
functionsWithDistinct: Seq[AggregateExpression],
functionsWithoutDistinct: Seq[AggregateExpression],
- aggregateFunctionToAttribute: Map[(AggregateFunction, Boolean), Attribute],
resultExpressions: Seq[NamedExpression],
child: SparkPlan): Seq[SparkPlan] = {
@@ -151,9 +144,7 @@ object Utils {
// 1. Create an Aggregate Operator for partial aggregations.
val partialAggregate: SparkPlan = {
val aggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = Partial))
- val aggregateAttributes = aggregateExpressions.map {
- expr => aggregateFunctionToAttribute(expr.aggregateFunction, expr.isDistinct)
- }
+ val aggregateAttributes = aggregateExpressions.map(_.resultAttribute)
// We will group by the original grouping expression, plus an additional expression for the
// DISTINCT column. For example, for AVG(DISTINCT value) GROUP BY key, the grouping
// expressions will be [key, value].
@@ -169,9 +160,7 @@ object Utils {
// 2. Create an Aggregate Operator for partial merge aggregations.
val partialMergeAggregate: SparkPlan = {
val aggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = PartialMerge))
- val aggregateAttributes = aggregateExpressions.map {
- expr => aggregateFunctionToAttribute(expr.aggregateFunction, expr.isDistinct)
- }
+ val aggregateAttributes = aggregateExpressions.map(_.resultAttribute)
createAggregate(
requiredChildDistributionExpressions =
Some(groupingAttributes ++ distinctAttributes),
@@ -190,7 +179,7 @@ object Utils {
// Children of an AggregateFunction with DISTINCT keyword has already
// been evaluated. At here, we need to replace original children
// to AttributeReferences.
- case agg @ AggregateExpression(aggregateFunction, mode, true) =>
+ case agg @ AggregateExpression(aggregateFunction, mode, true, _) =>
aggregateFunction.transformDown(distinctColumnAttributeLookup)
.asInstanceOf[AggregateFunction]
}
@@ -199,9 +188,7 @@ object Utils {
val mergeAggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = PartialMerge))
// The attributes of the final aggregation buffer, which is presented as input to the result
// projection:
- val mergeAggregateAttributes = mergeAggregateExpressions.map {
- expr => aggregateFunctionToAttribute(expr.aggregateFunction, expr.isDistinct)
- }
+ val mergeAggregateAttributes = mergeAggregateExpressions.map(_.resultAttribute)
val (distinctAggregateExpressions, distinctAggregateAttributes) =
rewrittenDistinctFunctions.zipWithIndex.map { case (func, i) =>
// We rewrite the aggregate function to a non-distinct aggregation because
@@ -211,7 +198,7 @@ object Utils {
val expr = AggregateExpression(func, Partial, isDistinct = true)
// Use original AggregationFunction to lookup attributes, which is used to build
// aggregateFunctionToAttribute
- val attr = aggregateFunctionToAttribute(functionsWithDistinct(i).aggregateFunction, true)
+ val attr = functionsWithDistinct(i).resultAttribute
(expr, attr)
}.unzip
@@ -232,9 +219,7 @@ object Utils {
val finalAggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = Final))
// The attributes of the final aggregation buffer, which is presented as input to the result
// projection:
- val finalAggregateAttributes = finalAggregateExpressions.map {
- expr => aggregateFunctionToAttribute(expr.aggregateFunction, expr.isDistinct)
- }
+ val finalAggregateAttributes = finalAggregateExpressions.map(_.resultAttribute)
val (distinctAggregateExpressions, distinctAggregateAttributes) =
rewrittenDistinctFunctions.zipWithIndex.map { case (func, i) =>
@@ -245,7 +230,7 @@ object Utils {
val expr = AggregateExpression(func, Final, isDistinct = true)
// Use original AggregationFunction to lookup attributes, which is used to build
// aggregateFunctionToAttribute
- val attr = aggregateFunctionToAttribute(functionsWithDistinct(i).aggregateFunction, true)
+ val attr = functionsWithDistinct(i).resultAttribute
(expr, attr)
}.unzip
@@ -261,4 +246,90 @@ object Utils {
finalAndCompleteAggregate :: Nil
}
+
+ /**
+ * Plans a streaming aggregation using the following progression:
+ * - Partial Aggregation
+ * - Shuffle
+ * - Partial Merge (now there is at most 1 tuple per group)
+ * - StateStoreRestore (now there is 1 tuple from this batch + optionally one from the previous)
+ * - PartialMerge (now there is at most 1 tuple per group)
+ * - StateStoreSave (saves the tuple for the next batch)
+ * - Complete (output the current result of the aggregation)
+ */
+ def planStreamingAggregation(
+ groupingExpressions: Seq[NamedExpression],
+ functionsWithoutDistinct: Seq[AggregateExpression],
+ resultExpressions: Seq[NamedExpression],
+ child: SparkPlan): Seq[SparkPlan] = {
+
+ val groupingAttributes = groupingExpressions.map(_.toAttribute)
+
+ val partialAggregate: SparkPlan = {
+ val aggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = Partial))
+ val aggregateAttributes = aggregateExpressions.map(_.resultAttribute)
+ // Group by the original grouping expressions and output the partial aggregation buffers,
+ // so that they can be merged after the shuffle.
+ createAggregate(
+ groupingExpressions = groupingExpressions,
+ aggregateExpressions = aggregateExpressions,
+ aggregateAttributes = aggregateAttributes,
+ resultExpressions = groupingAttributes ++
+ aggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes),
+ child = child)
+ }
+
+ val partialMerged1: SparkPlan = {
+ val aggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = PartialMerge))
+ val aggregateAttributes = aggregateExpressions.map(_.resultAttribute)
+ createAggregate(
+ requiredChildDistributionExpressions =
+ Some(groupingAttributes),
+ groupingExpressions = groupingAttributes,
+ aggregateExpressions = aggregateExpressions,
+ aggregateAttributes = aggregateAttributes,
+ initialInputBufferOffset = groupingAttributes.length,
+ resultExpressions = groupingAttributes ++
+ aggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes),
+ child = partialAggregate)
+ }
+
+ val restored = StateStoreRestore(groupingAttributes, None, partialMerged1)
+
+ val partialMerged2: SparkPlan = {
+ val aggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = PartialMerge))
+ val aggregateAttributes = aggregateExpressions.map(_.resultAttribute)
+ createAggregate(
+ requiredChildDistributionExpressions =
+ Some(groupingAttributes),
+ groupingExpressions = groupingAttributes,
+ aggregateExpressions = aggregateExpressions,
+ aggregateAttributes = aggregateAttributes,
+ initialInputBufferOffset = groupingAttributes.length,
+ resultExpressions = groupingAttributes ++
+ aggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes),
+ child = restored)
+ }
+
+ val saved = StateStoreSave(groupingAttributes, None, partialMerged2)
+
+ val finalAndCompleteAggregate: SparkPlan = {
+ val finalAggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = Final))
+ // The attributes of the final aggregation buffer, which is presented as input to the result
+ // projection:
+ val finalAggregateAttributes = finalAggregateExpressions.map(_.resultAttribute)
+
+ createAggregate(
+ requiredChildDistributionExpressions = Some(groupingAttributes),
+ groupingExpressions = groupingAttributes,
+ aggregateExpressions = finalAggregateExpressions,
+ aggregateAttributes = finalAggregateAttributes,
+ initialInputBufferOffset = groupingAttributes.length,
+ resultExpressions = resultExpressions,
+ child = saved)
+ }
+
+ finalAndCompleteAggregate :: Nil
+ }
}
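
A toy, non-Spark model of the progression documented on planStreamingAggregation above, using a per-key count and a plain Map as the "state store"; all names are hypothetical.

object StreamingAggSketch {
  type State = Map[String, Long]

  def partial(partition: Seq[String]): State =
    partition.groupBy(identity).map { case (k, vs) => k -> vs.size.toLong }

  def merge(a: State, b: State): State =
    (a.keySet ++ b.keySet).map(k => k -> (a.getOrElse(k, 0L) + b.getOrElse(k, 0L))).toMap

  def runBatch(previousState: State, partitions: Seq[Seq[String]]): (State, State) = {
    val partialMerged = partitions.map(partial).foldLeft(Map.empty[String, Long])(merge) // Partial + PartialMerge
    val withRestored  = merge(previousState, partialMerged)                              // StateStoreRestore + PartialMerge
    (withRestored, withRestored)                                                         // StateStoreSave; Final output
  }

  def main(args: Array[String]): Unit = {
    val (state1, out1) = runBatch(Map.empty, Seq(Seq("a", "b"), Seq("a")))
    val (state2, out2) = runBatch(state1, Seq(Seq("b", "b")))
    println(out1)  // Map(a -> 2, b -> 1)
    println(out2)  // Map(a -> 2, b -> 3)
  }
}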
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
index 70e04d022f..344aaff348 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
@@ -20,11 +20,11 @@ package org.apache.spark.sql.execution
import org.apache.spark.rdd.{PartitionwiseSampledRDD, RDD}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, ExpressionCanonicalizer}
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, ExpressionCanonicalizer, GenerateUnsafeProjection}
import org.apache.spark.sql.catalyst.plans.physical._
-import org.apache.spark.sql.execution.metric.{LongSQLMetricValue, SQLMetrics}
+import org.apache.spark.sql.execution.metric.SQLMetrics
import org.apache.spark.sql.types.LongType
-import org.apache.spark.util.random.PoissonSampler
+import org.apache.spark.util.random.{BernoulliCellSampler, PoissonSampler}
case class Project(projectList: Seq[NamedExpression], child: SparkPlan)
extends UnaryNode with CodegenSupport {
@@ -79,16 +79,20 @@ case class Filter(condition: Expression, child: SparkPlan)
// Split out all the IsNotNulls from condition.
private val (notNullPreds, otherPreds) = splitConjunctivePredicates(condition).partition {
- case IsNotNull(a) if child.output.contains(a) => true
+ case IsNotNull(a: NullIntolerant) if a.references.subsetOf(child.outputSet) => true
case _ => false
}
// The columns that will filtered out by `IsNotNull` could be considered as not nullable.
- private val notNullAttributes = notNullPreds.flatMap(_.references)
+ private val notNullAttributes = notNullPreds.flatMap(_.references).distinct.map(_.exprId)
+
+ // Mark this as empty. We'll evaluate the input during doConsume(). We don't want to evaluate
+ // all the variables at the beginning to take advantage of short circuiting.
+ override def usedInputs: AttributeSet = AttributeSet.empty
override def output: Seq[Attribute] = {
child.output.map { a =>
- if (a.nullable && notNullAttributes.contains(a)) {
+ if (a.nullable && notNullAttributes.contains(a.exprId)) {
a.withNullability(false)
} else {
a
@@ -110,39 +114,80 @@ case class Filter(condition: Expression, child: SparkPlan)
override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
val numOutput = metricTerm(ctx, "numOutputRows")
- // filter out the nulls
- val filterOutNull = notNullAttributes.map { a =>
- val idx = child.output.indexOf(a)
- s"if (${input(idx).isNull}) continue;"
- }.mkString("\n")
+ /**
+ * Generates code for `c`, using `in` for input attributes and `attrs` for nullability.
+ */
+ def genPredicate(c: Expression, in: Seq[ExprCode], attrs: Seq[Attribute]): String = {
+ val bound = BindReferences.bindReference(c, attrs)
+ val evaluated = evaluateRequiredVariables(child.output, in, c.references)
- ctx.currentVars = input
- val predicates = otherPreds.map { e =>
- val bound = ExpressionCanonicalizer.execute(
- BindReferences.bindReference(e, output))
- val ev = bound.gen(ctx)
+ // Generate the code for the predicate.
+ val ev = ExpressionCanonicalizer.execute(bound).gen(ctx)
val nullCheck = if (bound.nullable) {
s"${ev.isNull} || "
} else {
s""
}
+
s"""
+ |$evaluated
|${ev.code}
|if (${nullCheck}!${ev.value}) continue;
""".stripMargin
+ }
+
+ ctx.currentVars = input
+
+ // To generate the predicates we will follow this algorithm.
+ // For each predicate that is not IsNotNull, we will generate them one by one loading attributes
+ // as necessary. For each attribute it references, if there is an IsNotNull predicate we will generate
+ // that check *before* the predicate. After all of these predicates, we will generate the
+ // remaining IsNotNull checks that were not part of other predicates.
+ // This has the property of not doing redundant IsNotNull checks and taking better advantage of
+ // short-circuiting, not loading attributes until they are needed.
+ // This is very perf sensitive.
+ // TODO: revisit this. We can consider reordering predicates as well.
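+ //
+ // As a standalone sketch of this ordering (strings only, with a hypothetical predicate model;
+ // the real generation continues below): each IsNotNull check is emitted at most once, immediately
+ // before the first predicate that references its column, and any remaining null checks go last.
+ //
+ //   object FilterOrderingSketch {
+ //     // A non-IsNotNull predicate: its generated code plus the columns it references.
+ //     final case class Pred(code: String, references: Set[String])
+ //
+ //     def generate(notNullCols: Seq[String], otherPreds: Seq[Pred]): String = {
+ //       val emitted = scala.collection.mutable.Set.empty[String]
+ //       val lines = scala.collection.mutable.ArrayBuffer.empty[String]
+ //       otherPreds.foreach { p =>
+ //         notNullCols.filter(c => p.references.contains(c) && !emitted(c)).foreach { c =>
+ //           emitted += c
+ //           lines += s"if ($c == null) continue;"
+ //         }
+ //         lines += s"if (!(${p.code})) continue;"
+ //       }
+ //       notNullCols.filterNot(emitted).foreach(c => lines += s"if ($c == null) continue;")
+ //       lines.mkString("\n")
+ //     }
+ //
+ //     def main(args: Array[String]): Unit = {
+ //       println(generate(notNullCols = Seq("a", "b"), otherPreds = Seq(Pred("a > 10", Set("a")))))
+ //       // if (a == null) continue;
+ //       // if (!(a > 10)) continue;
+ //       // if (b == null) continue;
+ //     }
+ //   }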
+ val generatedIsNotNullChecks = new Array[Boolean](notNullPreds.length)
+ val generated = otherPreds.map { c =>
+ val nullChecks = c.references.map { r =>
+ val idx = notNullPreds.indexWhere { n => n.asInstanceOf[IsNotNull].child.semanticEquals(r)}
+ if (idx != -1 && !generatedIsNotNullChecks(idx)) {
+ generatedIsNotNullChecks(idx) = true
+ // Use the child's output. The nullability is what the child produced.
+ genPredicate(notNullPreds(idx), input, child.output)
+ } else {
+ ""
+ }
+ }.mkString("\n").trim
+
+ // Here we use *this* operator's output with this output's nullability since we already
+ // enforced them with the IsNotNull checks above.
+ s"""
+ |$nullChecks
+ |${genPredicate(c, input, output)}
+ """.stripMargin.trim
+ }.mkString("\n")
+
+ val nullChecks = notNullPreds.zipWithIndex.map { case (c, idx) =>
+ if (!generatedIsNotNullChecks(idx)) {
+ genPredicate(c, input, child.output)
+ } else {
+ ""
+ }
}.mkString("\n")
// Reset the isNull to false for the not-null columns, then the followed operators could
// generate better code (remove dead branches).
val resultVars = input.zipWithIndex.map { case (ev, i) =>
- if (notNullAttributes.contains(child.output(i))) {
+ if (notNullAttributes.contains(child.output(i).exprId)) {
ev.isNull = "false"
}
ev
}
+
s"""
- |$filterOutNull
- |$predicates
+ |$generated
+ |$nullChecks
|$numOutput.add(1);
|${consume(ctx, resultVars)}
""".stripMargin
@@ -178,9 +223,12 @@ case class Sample(
upperBound: Double,
withReplacement: Boolean,
seed: Long,
- child: SparkPlan) extends UnaryNode {
+ child: SparkPlan) extends UnaryNode with CodegenSupport {
override def output: Seq[Attribute] = child.output
+ private[sql] override lazy val metrics = Map(
+ "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))
+
protected override def doExecute(): RDD[InternalRow] = {
if (withReplacement) {
// Disable gap sampling since the gap sampling method buffers two rows internally,
@@ -194,6 +242,63 @@ case class Sample(
child.execute().randomSampleWithRange(lowerBound, upperBound, seed)
}
}
+
+ override def upstreams(): Seq[RDD[InternalRow]] = {
+ child.asInstanceOf[CodegenSupport].upstreams()
+ }
+
+ protected override def doProduce(ctx: CodegenContext): String = {
+ child.asInstanceOf[CodegenSupport].produce(ctx, this)
+ }
+
+ override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
+ val numOutput = metricTerm(ctx, "numOutputRows")
+ val sampler = ctx.freshName("sampler")
+
+ if (withReplacement) {
+ val samplerClass = classOf[PoissonSampler[UnsafeRow]].getName
+ val initSampler = ctx.freshName("initSampler")
+ ctx.addMutableState(s"$samplerClass<UnsafeRow>", sampler,
+ s"$initSampler();")
+
+ ctx.addNewFunction(initSampler,
+ s"""
+ | private void $initSampler() {
+ | $sampler = new $samplerClass<UnsafeRow>($upperBound - $lowerBound, false);
+ | java.util.Random random = new java.util.Random(${seed}L);
+ | long randomSeed = random.nextLong();
+ | int loopCount = 0;
+ | while (loopCount < partitionIndex) {
+ | randomSeed = random.nextLong();
+ | loopCount += 1;
+ | }
+ | $sampler.setSeed(randomSeed);
+ | }
+ """.stripMargin.trim)
+
+ val samplingCount = ctx.freshName("samplingCount")
+ s"""
+ | int $samplingCount = $sampler.sample();
+ | while ($samplingCount-- > 0) {
+ | $numOutput.add(1);
+ | ${consume(ctx, input)}
+ | }
+ """.stripMargin.trim
+ } else {
+ val samplerClass = classOf[BernoulliCellSampler[UnsafeRow]].getName
+ ctx.addMutableState(s"$samplerClass<UnsafeRow>", sampler,
+ s"""
+ | $sampler = new $samplerClass<UnsafeRow>($lowerBound, $upperBound, false);
+ | $sampler.setSeed(${seed}L + partitionIndex);
+ """.stripMargin.trim)
+
+ s"""
+ | if ($sampler.sample() == 0) continue;
+ | $numOutput.add(1);
+ | ${consume(ctx, input)}
+ """.stripMargin.trim
+ }
+ }
}
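
A small sketch of the per-partition seeding scheme used by the generated sampling code above: with replacement, a master java.util.Random seeded with `seed` is advanced once per partition index to derive that partition's seed; without replacement, the per-partition seed is simply seed + partitionIndex. The function names here are hypothetical.

object SamplerSeedSketch {
  def poissonSeed(seed: Long, partitionIndex: Int): Long = {
    val random = new java.util.Random(seed)
    var s = random.nextLong()
    var i = 0
    while (i < partitionIndex) { s = random.nextLong(); i += 1 }
    s
  }

  def bernoulliSeed(seed: Long, partitionIndex: Int): Long = seed + partitionIndex

  def main(args: Array[String]): Unit = {
    // Same (seed, partition) always yields the same per-partition seed, so recomputed
    // partitions sample the same rows.
    println(poissonSeed(42L, 3) == poissonSeed(42L, 3))   // true
    println(Seq(0, 1, 2).map(bernoulliSeed(42L, _)))      // List(42, 43, 44)
  }
}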
case class Range(
@@ -275,11 +380,7 @@ case class Range(
| // initialize Range
| if (!$initTerm) {
| $initTerm = true;
- | if ($input.hasNext()) {
- | initRange(((InternalRow) $input.next()).getInt(0));
- | } else {
- | return;
- | }
+ | initRange(partitionIndex);
| }
|
| while (!$overflow && $checkEnd) {
@@ -299,7 +400,7 @@ case class Range(
sqlContext
.sparkContext
.parallelize(0 until numSlices, numSlices)
- .mapPartitionsWithIndex((i, _) => {
+ .mapPartitionsWithIndex { (i, _) =>
val partitionStart = (i * numElements) / numSlices * step + start
val partitionEnd = (((i + 1) * numElements) / numSlices) * step + start
def getSafeMargin(bi: BigInt): Long =
@@ -343,7 +444,7 @@ case class Range(
unsafeRow
}
}
- })
+ }
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnAccessor.scala
index 78664baa56..7cde04b626 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnAccessor.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnAccessor.scala
@@ -38,7 +38,7 @@ private[columnar] trait ColumnAccessor {
def hasNext: Boolean
- def extractTo(row: MutableRow, ordinal: Int)
+ def extractTo(row: MutableRow, ordinal: Int): Unit
protected def underlyingBuffer: ByteBuffer
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnBuilder.scala
index 7e26f19bb7..d30655e0c4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnBuilder.scala
@@ -28,12 +28,12 @@ private[columnar] trait ColumnBuilder {
/**
* Initializes with an approximate lower bound on the expected number of elements in this column.
*/
- def initialize(initialSize: Int, columnName: String = "", useCompression: Boolean = false)
+ def initialize(initialSize: Int, columnName: String = "", useCompression: Boolean = false): Unit
/**
* Appends `row(ordinal)` to the column builder.
*/
- def appendFrom(row: InternalRow, ordinal: Int)
+ def appendFrom(row: InternalRow, ordinal: Int): Unit
/**
* Column statistics information
@@ -185,7 +185,7 @@ private[columnar] object ColumnBuilder {
case udt: UserDefinedType[_] =>
return apply(udt.sqlType, initialSize, columnName, useCompression)
case other =>
- throw new Exception(s"not suppported type: $other")
+ throw new Exception(s"not supported type: $other")
}
builder.initialize(initialSize, columnName, useCompression)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala
index d4e5db459f..e2e33e3246 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala
@@ -88,7 +88,7 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera
case array: ArrayType => classOf[ArrayColumnAccessor].getName
case t: MapType => classOf[MapColumnAccessor].getName
}
- ctx.addMutableState(accessorCls, accessorName, s"$accessorName = null;")
+ ctx.addMutableState(accessorCls, accessorName, "")
val createCode = dt match {
case t if ctx.isPrimitiveType(dt) =>
@@ -114,6 +114,42 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera
(createCode, extract + patch)
}.unzip
+ /*
+ * 200 = 6000 bytes / 30 (up to 30 bytes per call).
+ * The maximum bytecode size of a method to be compiled by HotSpot is 8000,
+ * so we should keep each generated method below that.
+ */
+ val numberOfStatementsThreshold = 200
+ val (initializerAccessorCalls, extractorCalls) =
+ if (initializeAccessors.length <= numberOfStatementsThreshold) {
+ (initializeAccessors.mkString("\n"), extractors.mkString("\n"))
+ } else {
+ val groupedAccessorsItr = initializeAccessors.grouped(numberOfStatementsThreshold)
+ val groupedExtractorsItr = extractors.grouped(numberOfStatementsThreshold)
+ var groupedAccessorsLength = 0
+ groupedAccessorsItr.zipWithIndex.map { case (body, i) =>
+ groupedAccessorsLength += 1
+ val funcName = s"accessors$i"
+ val funcCode = s"""
+ |private void $funcName() {
+ | ${body.mkString("\n")}
+ |}
+ """.stripMargin
+ ctx.addNewFunction(funcName, funcCode)
+ }
+ groupedExtractorsItr.zipWithIndex.map { case (body, i) =>
+ val funcName = s"extractors$i"
+ val funcCode = s"""
+ |private void $funcName() {
+ | ${body.mkString("\n")}
+ |}
+ """.stripMargin
+ ctx.addNewFunction(funcName, funcCode)
+ }
+ ((0 to groupedAccessorsLength - 1).map { i => s"accessors$i();" }.mkString("\n"),
+ (0 to groupedAccessorsLength - 1).map { i => s"extractors$i();" }.mkString("\n"))
+ }
+
val code = s"""
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
@@ -149,8 +185,6 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera
this.nativeOrder = ByteOrder.nativeOrder();
this.buffers = new byte[${columnTypes.length}][];
this.mutableRow = new MutableUnsafeRow(rowWriter);
-
- ${ctx.initMutableStates()}
}
public void initialize(Iterator input, DataType[] columnTypes, int[] columnIndexes) {
@@ -159,6 +193,8 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera
this.columnIndexes = columnIndexes;
}
+ ${ctx.declareAddedFunctions()}
+
public boolean hasNext() {
if (currentRow < numRowsInBatch) {
return true;
@@ -173,7 +209,7 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera
for (int i = 0; i < columnIndexes.length; i ++) {
buffers[i] = batch.buffers()[columnIndexes[i]];
}
- ${initializeAccessors.mkString("\n")}
+ ${initializerAccessorCalls}
return hasNext();
}
@@ -182,7 +218,7 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera
currentRow += 1;
bufferHolder.reset();
rowWriter.zeroOutNullBytes();
- ${extractors.mkString("\n")}
+ ${extractorCalls}
unsafeRow.setTotalSize(bufferHolder.totalSize());
return unsafeRow;
}
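Note: a simplified, self-contained sketch of the statement-grouping approach added above (the 200-statement threshold and the function-per-group shape mirror the hunk; this stands outside CodegenContext, so the names here are illustrative):

    object SplitGeneratedCode {
      // HotSpot will not JIT-compile methods whose bytecode exceeds ~8000 bytes, so when
      // the generated initializer/extractor bodies get long we emit helper methods of at
      // most `threshold` statements and call them from the hot path.
      val threshold = 200

      def split(name: String, statements: Seq[String]): (String, String) = {
        if (statements.length <= threshold) {
          ("", statements.mkString("\n"))
        } else {
          val groups = statements.grouped(threshold).toSeq
          val defs = groups.zipWithIndex.map { case (body, i) =>
            s"private void $name$i() {\n${body.mkString("\n")}\n}"
          }.mkString("\n")
          val calls = groups.indices.map(i => s"$name$i();").mkString("\n")
          (defs, calls) // (helper definitions, call sites to splice into the caller)
        }
      }
    }

In the patch itself the helper bodies are registered through ctx.addNewFunction and surfaced via ctx.declareAddedFunctions(), as the hunk shows.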
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AlterTableCommandParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AlterTableCommandParser.scala
deleted file mode 100644
index 9fbe6db467..0000000000
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AlterTableCommandParser.scala
+++ /dev/null
@@ -1,431 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.execution.command
-
-import scala.collection.mutable.ArrayBuffer
-
-import org.apache.spark.sql.catalyst.TableIdentifier
-import org.apache.spark.sql.catalyst.catalog.ExternalCatalog.TablePartitionSpec
-import org.apache.spark.sql.catalyst.expressions.{Ascending, Descending, SortDirection}
-import org.apache.spark.sql.catalyst.parser._
-import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
-import org.apache.spark.sql.execution.datasources._
-import org.apache.spark.sql.types.StructType
-
-
-/**
- * Helper object to parse alter table commands.
- */
-object AlterTableCommandParser {
- import ParserUtils._
-
- /**
- * Parse the given node assuming it is an alter table command.
- */
- def parse(node: ASTNode): LogicalPlan = {
- node.children match {
- case (tabName @ Token("TOK_TABNAME", _)) :: otherNodes =>
- val tableIdent = extractTableIdent(tabName)
- val partSpec = getClauseOption("TOK_PARTSPEC", node.children).map(parsePartitionSpec)
- matchAlterTableCommands(node, otherNodes, tableIdent, partSpec)
- case _ =>
- parseFailed("Could not parse ALTER TABLE command", node)
- }
- }
-
- private def cleanAndUnquoteString(s: String): String = {
- cleanIdentifier(unquoteString(s))
- }
-
- /**
- * Extract partition spec from the given [[ASTNode]] as a map, assuming it exists.
- *
- * Example format:
- *
- * TOK_PARTSPEC
- * :- TOK_PARTVAL
- * : :- dt
- * : +- '2008-08-08'
- * +- TOK_PARTVAL
- * :- country
- * +- 'us'
- */
- private def parsePartitionSpec(node: ASTNode): Map[String, String] = {
- node match {
- case Token("TOK_PARTSPEC", partitions) =>
- partitions.map {
- // Note: sometimes there's a "=", "<" or ">" between the key and the value
- // (e.g. when dropping all partitions with value > than a certain constant)
- case Token("TOK_PARTVAL", ident :: conj :: constant :: Nil) =>
- (cleanAndUnquoteString(ident.text), cleanAndUnquoteString(constant.text))
- case Token("TOK_PARTVAL", ident :: constant :: Nil) =>
- (cleanAndUnquoteString(ident.text), cleanAndUnquoteString(constant.text))
- case Token("TOK_PARTVAL", ident :: Nil) =>
- (cleanAndUnquoteString(ident.text), null)
- case _ =>
- parseFailed("Invalid ALTER TABLE command", node)
- }.toMap
- case _ =>
- parseFailed("Expected partition spec in ALTER TABLE command", node)
- }
- }
-
- /**
- * Extract table properties from the given [[ASTNode]] as a map, assuming it exists.
- *
- * Example format:
- *
- * TOK_TABLEPROPERTIES
- * +- TOK_TABLEPROPLIST
- * :- TOK_TABLEPROPERTY
- * : :- 'test'
- * : +- 'value'
- * +- TOK_TABLEPROPERTY
- * :- 'comment'
- * +- 'new_comment'
- */
- private def extractTableProps(node: ASTNode): Map[String, String] = {
- node match {
- case Token("TOK_TABLEPROPERTIES", propsList) =>
- propsList.flatMap {
- case Token("TOK_TABLEPROPLIST", props) =>
- props.map { case Token("TOK_TABLEPROPERTY", key :: value :: Nil) =>
- val k = cleanAndUnquoteString(key.text)
- val v = value match {
- case Token("TOK_NULL", Nil) => null
- case _ => cleanAndUnquoteString(value.text)
- }
- (k, v)
- }
- case _ =>
- parseFailed("Invalid ALTER TABLE command", node)
- }.toMap
- case _ =>
- parseFailed("Expected table properties in ALTER TABLE command", node)
- }
- }
-
- /**
- * Parse an alter table command from a [[ASTNode]] into a [[LogicalPlan]].
- * This follows https://cwiki.apache.org/confluence/display/Hive/LanguageManual+DDL.
- *
- * @param node the original [[ASTNode]] to parse.
- * @param otherNodes the other [[ASTNode]]s after the first one containing the table name.
- * @param tableIdent identifier of the table, parsed from the first [[ASTNode]].
- * @param partition spec identifying the partition this command is concerned with, if any.
- */
- // TODO: This method is massive. Break it down.
- private def matchAlterTableCommands(
- node: ASTNode,
- otherNodes: Seq[ASTNode],
- tableIdent: TableIdentifier,
- partition: Option[TablePartitionSpec]): LogicalPlan = {
- otherNodes match {
- // ALTER TABLE table_name RENAME TO new_table_name;
- case Token("TOK_ALTERTABLE_RENAME", renameArgs) :: _ =>
- val tableNameClause = getClause("TOK_TABNAME", renameArgs)
- val newTableIdent = extractTableIdent(tableNameClause)
- AlterTableRename(tableIdent, newTableIdent)(node.source)
-
- // ALTER TABLE table_name SET TBLPROPERTIES ('comment' = new_comment);
- case Token("TOK_ALTERTABLE_PROPERTIES", args) :: _ =>
- val properties = extractTableProps(args.head)
- AlterTableSetProperties(tableIdent, properties)(node.source)
-
- // ALTER TABLE table_name UNSET TBLPROPERTIES IF EXISTS ('comment', 'key');
- case Token("TOK_ALTERTABLE_DROPPROPERTIES", args) :: _ =>
- val properties = extractTableProps(args.head)
- val ifExists = getClauseOption("TOK_IFEXISTS", args).isDefined
- AlterTableUnsetProperties(tableIdent, properties, ifExists)(node.source)
-
- // ALTER TABLE table_name [PARTITION spec] SET SERDE serde_name [WITH SERDEPROPERTIES props];
- case Token("TOK_ALTERTABLE_SERIALIZER", Token(serdeClassName, Nil) :: serdeArgs) :: _ =>
- AlterTableSerDeProperties(
- tableIdent,
- Some(cleanAndUnquoteString(serdeClassName)),
- serdeArgs.headOption.map(extractTableProps),
- partition)(node.source)
-
- // ALTER TABLE table_name [PARTITION spec] SET SERDEPROPERTIES serde_properties;
- case Token("TOK_ALTERTABLE_SERDEPROPERTIES", args) :: _ =>
- AlterTableSerDeProperties(
- tableIdent,
- None,
- Some(extractTableProps(args.head)),
- partition)(node.source)
-
- // ALTER TABLE table_name CLUSTERED BY (col, ...) [SORTED BY (col, ...)] INTO n BUCKETS;
- case Token("TOK_ALTERTABLE_CLUSTER_SORT", Token("TOK_ALTERTABLE_BUCKETS", b) :: Nil) :: _ =>
- val clusterCols: Seq[String] = b.head match {
- case Token("TOK_TABCOLNAME", children) => children.map(_.text)
- case _ => parseFailed("Invalid ALTER TABLE command", node)
- }
- // If sort columns are specified, num buckets should be the third arg.
- // If sort columns are not specified, num buckets should be the second arg.
- // TODO: actually use `sortDirections` once we actually store that in the metastore
- val (sortCols: Seq[String], sortDirections: Seq[SortDirection], numBuckets: Int) = {
- b.tail match {
- case Token("TOK_TABCOLNAME", children) :: numBucketsNode :: Nil =>
- val (cols, directions) = children.map {
- case Token("TOK_TABSORTCOLNAMEASC", Token(col, Nil) :: Nil) => (col, Ascending)
- case Token("TOK_TABSORTCOLNAMEDESC", Token(col, Nil) :: Nil) => (col, Descending)
- }.unzip
- (cols, directions, numBucketsNode.text.toInt)
- case numBucketsNode :: Nil =>
- (Nil, Nil, numBucketsNode.text.toInt)
- case _ =>
- parseFailed("Invalid ALTER TABLE command", node)
- }
- }
- AlterTableStorageProperties(
- tableIdent,
- BucketSpec(numBuckets, clusterCols, sortCols))(node.source)
-
- // ALTER TABLE table_name NOT CLUSTERED
- case Token("TOK_ALTERTABLE_CLUSTER_SORT", Token("TOK_NOT_CLUSTERED", Nil) :: Nil) :: _ =>
- AlterTableNotClustered(tableIdent)(node.source)
-
- // ALTER TABLE table_name NOT SORTED
- case Token("TOK_ALTERTABLE_CLUSTER_SORT", Token("TOK_NOT_SORTED", Nil) :: Nil) :: _ =>
- AlterTableNotSorted(tableIdent)(node.source)
-
- // ALTER TABLE table_name SKEWED BY (col1, col2)
- // ON ((col1_value, col2_value) [, (col1_value, col2_value), ...])
- // [STORED AS DIRECTORIES];
- case Token("TOK_ALTERTABLE_SKEWED",
- Token("TOK_TABLESKEWED",
- Token("TOK_TABCOLNAME", colNames) :: colValues :: rest) :: Nil) :: _ =>
- // Example format:
- //
- // TOK_ALTERTABLE_SKEWED
- // :- TOK_TABLESKEWED
- // : :- TOK_TABCOLNAME
- // : : :- dt
- // : : +- country
- // :- TOK_TABCOLVALUE_PAIR
- // : :- TOK_TABCOLVALUES
- // : : :- TOK_TABCOLVALUE
- // : : : :- '2008-08-08'
- // : : : +- 'us'
- // : :- TOK_TABCOLVALUES
- // : : :- TOK_TABCOLVALUE
- // : : : :- '2009-09-09'
- // : : : +- 'uk'
- // +- TOK_STOREASDIR
- val names = colNames.map { n => cleanAndUnquoteString(n.text) }
- val values = colValues match {
- case Token("TOK_TABCOLVALUE", vals) =>
- Seq(vals.map { n => cleanAndUnquoteString(n.text) })
- case Token("TOK_TABCOLVALUE_PAIR", pairs) =>
- pairs.map {
- case Token("TOK_TABCOLVALUES", Token("TOK_TABCOLVALUE", vals) :: Nil) =>
- vals.map { n => cleanAndUnquoteString(n.text) }
- case _ =>
- parseFailed("Invalid ALTER TABLE command", node)
- }
- case _ =>
- parseFailed("Invalid ALTER TABLE command", node)
- }
- val storedAsDirs = rest match {
- case Token("TOK_STOREDASDIRS", Nil) :: Nil => true
- case _ => false
- }
- AlterTableSkewed(
- tableIdent,
- names,
- values,
- storedAsDirs)(node.source)
-
- // ALTER TABLE table_name NOT SKEWED
- case Token("TOK_ALTERTABLE_SKEWED", Nil) :: _ =>
- AlterTableNotSkewed(tableIdent)(node.source)
-
- // ALTER TABLE table_name NOT STORED AS DIRECTORIES
- case Token("TOK_ALTERTABLE_SKEWED", Token("TOK_STOREDASDIRS", Nil) :: Nil) :: _ =>
- AlterTableNotStoredAsDirs(tableIdent)(node.source)
-
- // ALTER TABLE table_name SET SKEWED LOCATION (col1="loc1" [, (col2, col3)="loc2", ...] );
- case Token("TOK_ALTERTABLE_SKEWED_LOCATION",
- Token("TOK_SKEWED_LOCATIONS",
- Token("TOK_SKEWED_LOCATION_LIST", locationMaps) :: Nil) :: Nil) :: _ =>
- // Example format:
- //
- // TOK_ALTERTABLE_SKEWED_LOCATION
- // +- TOK_SKEWED_LOCATIONS
- // +- TOK_SKEWED_LOCATION_LIST
- // :- TOK_SKEWED_LOCATION_MAP
- // : :- 'col1'
- // : +- 'loc1'
- // +- TOK_SKEWED_LOCATION_MAP
- // :- TOK_TABCOLVALUES
- // : +- TOK_TABCOLVALUE
- // : :- 'col2'
- // : +- 'col3'
- // +- 'loc2'
- val skewedMaps = locationMaps.flatMap {
- case Token("TOK_SKEWED_LOCATION_MAP", col :: loc :: Nil) =>
- col match {
- case Token(const, Nil) =>
- Seq((cleanAndUnquoteString(const), cleanAndUnquoteString(loc.text)))
- case Token("TOK_TABCOLVALUES", Token("TOK_TABCOLVALUE", keys) :: Nil) =>
- keys.map { k => (cleanAndUnquoteString(k.text), cleanAndUnquoteString(loc.text)) }
- }
- case _ =>
- parseFailed("Invalid ALTER TABLE command", node)
- }.toMap
- AlterTableSkewedLocation(tableIdent, skewedMaps)(node.source)
-
- // ALTER TABLE table_name ADD [IF NOT EXISTS] PARTITION spec [LOCATION 'loc1']
- // spec [LOCATION 'loc2'] ...;
- case Token("TOK_ALTERTABLE_ADDPARTS", args) :: _ =>
- val (ifNotExists, parts) = args.head match {
- case Token("TOK_IFNOTEXISTS", Nil) => (true, args.tail)
- case _ => (false, args)
- }
- // List of (spec, location) to describe partitions to add
- // Each partition spec may or may not be followed by a location
- val parsedParts = new ArrayBuffer[(TablePartitionSpec, Option[String])]
- parts.foreach {
- case t @ Token("TOK_PARTSPEC", _) =>
- parsedParts += ((parsePartitionSpec(t), None))
- case Token("TOK_PARTITIONLOCATION", loc :: Nil) =>
- // Update the location of the last partition we just added
- if (parsedParts.nonEmpty) {
- val (spec, _) = parsedParts.remove(parsedParts.length - 1)
- parsedParts += ((spec, Some(unquoteString(loc.text))))
- }
- case _ =>
- parseFailed("Invalid ALTER TABLE command", node)
- }
- AlterTableAddPartition(tableIdent, parsedParts, ifNotExists)(node.source)
-
- // ALTER TABLE table_name PARTITION spec1 RENAME TO PARTITION spec2;
- case Token("TOK_ALTERTABLE_RENAMEPART", spec :: Nil) :: _ =>
- val newPartition = parsePartitionSpec(spec)
- val oldPartition = partition.getOrElse {
- parseFailed("Expected old partition spec in ALTER TABLE rename partition command", node)
- }
- AlterTableRenamePartition(tableIdent, oldPartition, newPartition)(node.source)
-
- // ALTER TABLE table_name_1 EXCHANGE PARTITION spec WITH TABLE table_name_2;
- case Token("TOK_ALTERTABLE_EXCHANGEPARTITION", spec :: newTable :: Nil) :: _ =>
- val parsedSpec = parsePartitionSpec(spec)
- val newTableIdent = extractTableIdent(newTable)
- AlterTableExchangePartition(tableIdent, newTableIdent, parsedSpec)(node.source)
-
- // ALTER TABLE table_name DROP [IF EXISTS] PARTITION spec1[, PARTITION spec2, ...] [PURGE];
- case Token("TOK_ALTERTABLE_DROPPARTS", args) :: _ =>
- val parts = args.collect { case p @ Token("TOK_PARTSPEC", _) => parsePartitionSpec(p) }
- val ifExists = getClauseOption("TOK_IFEXISTS", args).isDefined
- val purge = getClauseOption("PURGE", args).isDefined
- AlterTableDropPartition(tableIdent, parts, ifExists, purge)(node.source)
-
- // ALTER TABLE table_name ARCHIVE PARTITION spec;
- case Token("TOK_ALTERTABLE_ARCHIVE", spec :: Nil) :: _ =>
- AlterTableArchivePartition(tableIdent, parsePartitionSpec(spec))(node.source)
-
- // ALTER TABLE table_name UNARCHIVE PARTITION spec;
- case Token("TOK_ALTERTABLE_UNARCHIVE", spec :: Nil) :: _ =>
- AlterTableUnarchivePartition(tableIdent, parsePartitionSpec(spec))(node.source)
-
- // ALTER TABLE table_name [PARTITION spec] SET FILEFORMAT file_format;
- case Token("TOK_ALTERTABLE_FILEFORMAT", args) :: _ =>
- val Seq(fileFormat, genericFormat) =
- getClauses(Seq("TOK_TABLEFILEFORMAT", "TOK_FILEFORMAT_GENERIC"), args)
- // Note: the AST doesn't contain information about which file format is being set here.
- // E.g. we can't differentiate between INPUTFORMAT and OUTPUTFORMAT if either is set.
- // Right now this just stores the values, but we should figure out how to get the keys.
- val fFormat = fileFormat
- .map { _.children.map { n => cleanAndUnquoteString(n.text) }}
- .getOrElse(Seq())
- val gFormat = genericFormat.map { f => cleanAndUnquoteString(f.children(0).text) }
- AlterTableSetFileFormat(tableIdent, partition, fFormat, gFormat)(node.source)
-
- // ALTER TABLE table_name [PARTITION spec] SET LOCATION "loc";
- case Token("TOK_ALTERTABLE_LOCATION", Token(loc, Nil) :: Nil) :: _ =>
- AlterTableSetLocation(tableIdent, partition, cleanAndUnquoteString(loc))(node.source)
-
- // ALTER TABLE table_name TOUCH [PARTITION spec];
- case Token("TOK_ALTERTABLE_TOUCH", args) :: _ =>
- // Note: the partition spec, if it exists, comes after TOUCH, so `partition` should
- // always be None here. Instead, we need to parse it from the TOUCH node's children.
- val part = getClauseOption("TOK_PARTSPEC", args).map(parsePartitionSpec)
- AlterTableTouch(tableIdent, part)(node.source)
-
- // ALTER TABLE table_name [PARTITION spec] COMPACT 'compaction_type';
- case Token("TOK_ALTERTABLE_COMPACT", Token(compactType, Nil) :: Nil) :: _ =>
- AlterTableCompact(tableIdent, partition, cleanAndUnquoteString(compactType))(node.source)
-
- // ALTER TABLE table_name [PARTITION spec] CONCATENATE;
- case Token("TOK_ALTERTABLE_MERGEFILES", _) :: _ =>
- AlterTableMerge(tableIdent, partition)(node.source)
-
- // ALTER TABLE table_name [PARTITION spec] CHANGE [COLUMN] col_old_name col_new_name
- // column_type [COMMENT col_comment] [FIRST|AFTER column_name] [CASCADE|RESTRICT];
- case Token("TOK_ALTERTABLE_RENAMECOL", oldName :: newName :: dataType :: args) :: _ =>
- val afterColName: Option[String] =
- getClauseOption("TOK_ALTERTABLE_CHANGECOL_AFTER_POSITION", args).map { ap =>
- ap.children match {
- case Token(col, Nil) :: Nil => col
- case _ => parseFailed("Invalid ALTER TABLE command", node)
- }
- }
- val restrict = getClauseOption("TOK_RESTRICT", args).isDefined
- val cascade = getClauseOption("TOK_CASCADE", args).isDefined
- val comment = args.headOption.map {
- case Token("TOK_ALTERTABLE_CHANGECOL_AFTER_POSITION", _) => null
- case Token("TOK_RESTRICT", _) => null
- case Token("TOK_CASCADE", _) => null
- case Token(commentStr, Nil) => cleanAndUnquoteString(commentStr)
- case _ => parseFailed("Invalid ALTER TABLE command", node)
- }
- AlterTableChangeCol(
- tableIdent,
- partition,
- oldName.text,
- newName.text,
- nodeToDataType(dataType),
- comment,
- afterColName,
- restrict,
- cascade)(node.source)
-
- // ALTER TABLE table_name [PARTITION spec] ADD COLUMNS (name type [COMMENT comment], ...)
- // [CASCADE|RESTRICT]
- case Token("TOK_ALTERTABLE_ADDCOLS", args) :: _ =>
- val columnNodes = getClause("TOK_TABCOLLIST", args).children
- val columns = StructType(columnNodes.map(nodeToStructField))
- val restrict = getClauseOption("TOK_RESTRICT", args).isDefined
- val cascade = getClauseOption("TOK_CASCADE", args).isDefined
- AlterTableAddCol(tableIdent, partition, columns, restrict, cascade)(node.source)
-
- // ALTER TABLE table_name [PARTITION spec] REPLACE COLUMNS (name type [COMMENT comment], ...)
- // [CASCADE|RESTRICT]
- case Token("TOK_ALTERTABLE_REPLACECOLS", args) :: _ =>
- val columnNodes = getClause("TOK_TABCOLLIST", args).children
- val columns = StructType(columnNodes.map(nodeToStructField))
- val restrict = getClauseOption("TOK_RESTRICT", args).isDefined
- val cascade = getClauseOption("TOK_CASCADE", args).isDefined
- AlterTableReplaceCol(tableIdent, partition, columns, restrict, cascade)(node.source)
-
- case _ =>
- parseFailed("Unsupported ALTER TABLE command", node)
- }
- }
-
-}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala
index 964f0a7a7b..5d00c805a6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala
@@ -21,17 +21,17 @@ import java.util.NoSuchElementException
import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{Dataset, Row, SQLContext}
+import org.apache.spark.sql.{AnalysisException, Dataset, Row, SQLContext}
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, TableIdentifier}
import org.apache.spark.sql.catalyst.errors.TreeNodeException
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
import org.apache.spark.sql.catalyst.plans.logical
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.execution.debug._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
-
/**
* A logical command that is executed for its side-effects. `RunnableCommand`s are
* wrapped in `ExecutedCommand` during execution.
@@ -237,16 +237,23 @@ case class ExplainCommand(
logicalPlan: LogicalPlan,
override val output: Seq[Attribute] =
Seq(AttributeReference("plan", StringType, nullable = true)()),
- extended: Boolean = false)
+ extended: Boolean = false,
+ codegen: Boolean = false)
extends RunnableCommand {
// Run through the optimizer to generate the physical plan.
override def run(sqlContext: SQLContext): Seq[Row] = try {
// TODO in Hive, the "extended" ExplainCommand prints the AST as well, and detailed properties.
val queryExecution = sqlContext.executePlan(logicalPlan)
- val outputString = if (extended) queryExecution.toString else queryExecution.simpleString
-
- outputString.split("\n").map(Row(_))
+ val outputString =
+ if (codegen) {
+ codegenString(queryExecution.executedPlan)
+ } else if (extended) {
+ queryExecution.toString
+ } else {
+ queryExecution.simpleString
+ }
+ Seq(Row(outputString))
} catch { case cause: TreeNodeException[_] =>
("Error occurred during query planning: \n" + cause.getMessage).split("\n").map(Row(_))
}
@@ -322,18 +329,17 @@ case class DescribeCommand(
* If a databaseName is not given, the current database will be used.
* The syntax of using this command in SQL is:
* {{{
- * SHOW TABLES [IN databaseName]
+ * SHOW TABLES [(IN|FROM) database_name] [[LIKE] 'identifier_with_wildcards'];
* }}}
*/
-case class ShowTablesCommand(databaseName: Option[String]) extends RunnableCommand {
+case class ShowTablesCommand(
+ databaseName: Option[String],
+ tableIdentifierPattern: Option[String]) extends RunnableCommand {
// The result of SHOW TABLES has two columns, tableName and isTemporary.
override val output: Seq[Attribute] = {
- val schema = StructType(
- StructField("tableName", StringType, false) ::
- StructField("isTemporary", BooleanType, false) :: Nil)
-
- schema.toAttributes
+ AttributeReference("tableName", StringType, nullable = false)() ::
+ AttributeReference("isTemporary", BooleanType, nullable = false)() :: Nil
}
override def run(sqlContext: SQLContext): Seq[Row] = {
@@ -341,11 +347,78 @@ case class ShowTablesCommand(databaseName: Option[String]) extends RunnableComma
// instead of calling tables in sqlContext.
val catalog = sqlContext.sessionState.catalog
val db = databaseName.getOrElse(catalog.getCurrentDatabase)
- val rows = catalog.listTables(db).map { t =>
+ val tables =
+ tableIdentifierPattern.map(catalog.listTables(db, _)).getOrElse(catalog.listTables(db))
+ tables.map { t =>
val isTemp = t.database.isEmpty
Row(t.table, isTemp)
}
- rows
+ }
+}
+
+/**
+ * A command for users to list the databases/schemas.
+ * If a databasePattern is supplied, only the databases whose names match the
+ * pattern are listed.
+ * The syntax of using this command in SQL is:
+ * {{{
+ * SHOW (DATABASES|SCHEMAS) [LIKE 'identifier_with_wildcards'];
+ * }}}
+ */
+case class ShowDatabasesCommand(databasePattern: Option[String]) extends RunnableCommand {
+
+ // The result of SHOW DATABASES has one column called 'result'
+ override val output: Seq[Attribute] = {
+ AttributeReference("result", StringType, nullable = false)() :: Nil
+ }
+
+ override def run(sqlContext: SQLContext): Seq[Row] = {
+ val catalog = sqlContext.sessionState.catalog
+ val databases =
+ databasePattern.map(catalog.listDatabases(_)).getOrElse(catalog.listDatabases())
+ databases.map { d => Row(d) }
+ }
+}
+
+/**
+ * A command for users to list the properties of a table. If propertyKey is specified, the value
+ * for that key is returned. If propertyKey is not specified, all keys and their
+ * corresponding values are returned.
+ * The syntax of using this command in SQL is:
+ * {{{
+ * SHOW TBLPROPERTIES table_name[('propertyKey')];
+ * }}}
+ */
+case class ShowTablePropertiesCommand(
+ table: TableIdentifier,
+ propertyKey: Option[String]) extends RunnableCommand {
+
+ override val output: Seq[Attribute] = {
+ val schema = AttributeReference("value", StringType, nullable = false)() :: Nil
+ propertyKey match {
+ case None => AttributeReference("key", StringType, nullable = false)() :: schema
+ case _ => schema
+ }
+ }
+
+ override def run(sqlContext: SQLContext): Seq[Row] = {
+ val catalog = sqlContext.sessionState.catalog
+
+ if (catalog.isTemporaryTable(table)) {
+ Seq.empty[Row]
+ } else {
+ val catalogTable = sqlContext.sessionState.catalog.getTableMetadata(table)
+
+ propertyKey match {
+ case Some(p) =>
+ val propValue = catalogTable
+ .properties
+ .getOrElse(p, s"Table ${catalogTable.qualifiedName} does not have property: $p")
+ Seq(Row(propValue))
+ case None =>
+ catalogTable.properties.map(p => Row(p._1, p._2)).toSeq
+ }
+ }
}
}
@@ -353,8 +426,12 @@ case class ShowTablesCommand(databaseName: Option[String]) extends RunnableComma
* A command for users to list all of the registered functions.
* The syntax of using this command in SQL is:
* {{{
- * SHOW FUNCTIONS
+ * SHOW FUNCTIONS [LIKE pattern]
* }}}
+ * For the pattern, '*' matches any sequence of characters (including no characters) and
+ * '|' is for alternation.
+ * For example, "show functions like 'yea*|windo*'" will return "window" and "year".
+ *
 * TODO: currently we simply ignore the db
*/
case class ShowFunctions(db: Option[String], pattern: Option[String]) extends RunnableCommand {
@@ -365,18 +442,17 @@ case class ShowFunctions(db: Option[String], pattern: Option[String]) extends Ru
schema.toAttributes
}
- override def run(sqlContext: SQLContext): Seq[Row] = pattern match {
- case Some(p) =>
- try {
- val regex = java.util.regex.Pattern.compile(p)
- sqlContext.sessionState.functionRegistry.listFunction()
- .filter(regex.matcher(_).matches()).map(Row(_))
- } catch {
- // probably will failed in the regex that user provided, then returns empty row.
- case _: Throwable => Seq.empty[Row]
- }
- case None =>
- sqlContext.sessionState.functionRegistry.listFunction().map(Row(_))
+ override def run(sqlContext: SQLContext): Seq[Row] = {
+ val dbName = db.getOrElse(sqlContext.sessionState.catalog.getCurrentDatabase)
+ // If pattern is not specified, we use '*', which is used to
+ // match any sequence of characters (including no characters).
+ val functionNames =
+ sqlContext.sessionState.catalog
+ .listFunctions(dbName, pattern.getOrElse("*"))
+ .map(_.unquotedString)
+ // The session catalog caches some persistent functions in the FunctionRegistry
+ // so there can be duplicates.
+ functionNames.distinct.sorted.map(Row(_))
}
}
@@ -407,20 +483,38 @@ case class DescribeFunction(
}
override def run(sqlContext: SQLContext): Seq[Row] = {
- sqlContext.sessionState.functionRegistry.lookupFunction(functionName) match {
- case Some(info) =>
- val result =
- Row(s"Function: ${info.getName}") ::
- Row(s"Class: ${info.getClassName}") ::
- Row(s"Usage: ${replaceFunctionName(info.getUsage(), info.getName)}") :: Nil
-
- if (isExtended) {
- result :+ Row(s"Extended Usage:\n${replaceFunctionName(info.getExtended, info.getName)}")
- } else {
- result
- }
+    // Hard-code "<>", "!=", "between", and "case" for now, as there are no corresponding functions.
+ functionName.toLowerCase match {
+ case "<>" =>
+ Row(s"Function: $functionName") ::
+ Row(s"Usage: a <> b - Returns TRUE if a is not equal to b") :: Nil
+ case "!=" =>
+ Row(s"Function: $functionName") ::
+ Row(s"Usage: a != b - Returns TRUE if a is not equal to b") :: Nil
+ case "between" =>
+ Row(s"Function: between") ::
+ Row(s"Usage: a [NOT] BETWEEN b AND c - " +
+ s"evaluate if a is [not] in between b and c") :: Nil
+ case "case" =>
+ Row(s"Function: case") ::
+ Row(s"Usage: CASE a WHEN b THEN c [WHEN d THEN e]* [ELSE f] END - " +
+ s"When a = b, returns c; when a = d, return e; else return f") :: Nil
+ case _ => sqlContext.sessionState.functionRegistry.lookupFunction(functionName) match {
+ case Some(info) =>
+ val result =
+ Row(s"Function: ${info.getName}") ::
+ Row(s"Class: ${info.getClassName}") ::
+ Row(s"Usage: ${replaceFunctionName(info.getUsage(), info.getName)}") :: Nil
+
+ if (isExtended) {
+ result :+
+ Row(s"Extended Usage:\n${replaceFunctionName(info.getExtended, info.getName)}")
+ } else {
+ result
+ }
- case None => Seq(Row(s"Function: $functionName not found."))
+ case None => Seq(Row(s"Function: $functionName not found."))
+ }
}
}
}
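Note: the commands reworked above are reachable through plain SQL, e.g. in spark-shell where sqlContext is predefined. The statement shapes below follow the Scaladoc in this hunk; table, database, and pattern names are placeholders, a Hive-backed catalog with those objects is assumed, and whether a given front end accepts every form (notably EXPLAIN CODEGEN) depends on the parser at this point in the branch:

    // Illustrative only.
    sqlContext.sql("EXPLAIN CODEGEN SELECT 1").show()             // exercises the new codegen = true branch
    sqlContext.sql("SHOW TABLES IN sales_db LIKE 'fact*'").show()
    sqlContext.sql("SHOW DATABASES LIKE 'prod*'").show()
    sqlContext.sql("SHOW TBLPROPERTIES fact_orders('comment')").show()
    sqlContext.sql("SHOW FUNCTIONS LIKE 'yea*|windo*'").show()
    sqlContext.sql("DESCRIBE FUNCTION EXTENDED year").show()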
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala
index 373b557683..fc37a142cd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala
@@ -17,15 +17,19 @@
package org.apache.spark.sql.execution.command
+import scala.util.control.NonFatal
+
import org.apache.spark.internal.Logging
-import org.apache.spark.sql.{Row, SQLContext}
+import org.apache.spark.sql.{AnalysisException, Row, SQLContext}
import org.apache.spark.sql.catalyst.TableIdentifier
+import org.apache.spark.sql.catalyst.catalog.{CatalogDatabase, CatalogTable}
+import org.apache.spark.sql.catalyst.catalog.{CatalogTablePartition, CatalogTableType, SessionCatalog}
import org.apache.spark.sql.catalyst.catalog.ExternalCatalog.TablePartitionSpec
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
-import org.apache.spark.sql.execution.datasources.BucketSpec
import org.apache.spark.sql.types._
+
// Note: The definition of these commands are based on the ones described in
// https://cwiki.apache.org/confluence/display/Hive/LanguageManual+DDL
@@ -44,131 +48,379 @@ abstract class NativeDDLCommand(val sql: String) extends RunnableCommand {
}
+/**
+ * A command for users to create a new database.
+ *
+ * It will issue an error message when the database with the same name already exists,
+ * unless 'ifNotExists' is true.
+ * The syntax of using this command in SQL is:
+ * {{{
+ * CREATE DATABASE|SCHEMA [IF NOT EXISTS] database_name
+ * }}}
+ */
case class CreateDatabase(
databaseName: String,
ifNotExists: Boolean,
path: Option[String],
comment: Option[String],
- props: Map[String, String])(sql: String)
- extends NativeDDLCommand(sql) with Logging
+ props: Map[String, String])
+ extends RunnableCommand {
+
+ override def run(sqlContext: SQLContext): Seq[Row] = {
+ val catalog = sqlContext.sessionState.catalog
+ catalog.createDatabase(
+ CatalogDatabase(
+ databaseName,
+ comment.getOrElse(""),
+ path.getOrElse(catalog.getDefaultDBPath(databaseName)),
+ props),
+ ifNotExists)
+ Seq.empty[Row]
+ }
+
+ override val output: Seq[Attribute] = Seq.empty
+}
+
/**
- * Drop Database: Removes a database from the system.
+ * A command for users to remove a database from the system.
*
* 'ifExists':
 * - true, if database_name doesn't exist, no action
 * - false (default), if database_name doesn't exist, a warning message will be issued
- * 'restric':
- * - true (default), the database cannot be dropped if it is not empty. The inclusive
- * tables must be dropped at first.
- * - false, it is in the Cascade mode. The dependent objects are automatically dropped
- * before dropping database.
+ * 'cascade':
+ * - true, the dependent objects are automatically dropped before dropping database.
+ * - false (default), it is in the Restrict mode. The database cannot be dropped if
+ * it is not empty. The inclusive tables must be dropped at first.
+ *
+ * The syntax of using this command in SQL is:
+ * {{{
+ * DROP DATABASE [IF EXISTS] database_name [RESTRICT|CASCADE];
+ * }}}
*/
case class DropDatabase(
databaseName: String,
ifExists: Boolean,
- restrict: Boolean)(sql: String)
- extends NativeDDLCommand(sql) with Logging
+ cascade: Boolean)
+ extends RunnableCommand {
-case class CreateFunction(
- functionName: String,
- alias: String,
- resources: Seq[(String, String)],
- isTemp: Boolean)(sql: String)
- extends NativeDDLCommand(sql) with Logging
+ override def run(sqlContext: SQLContext): Seq[Row] = {
+ sqlContext.sessionState.catalog.dropDatabase(databaseName, ifExists, cascade)
+ Seq.empty[Row]
+ }
-case class AlterTableRename(
- oldName: TableIdentifier,
- newName: TableIdentifier)(sql: String)
- extends NativeDDLCommand(sql) with Logging
+ override val output: Seq[Attribute] = Seq.empty
+}
-case class AlterTableSetProperties(
+/**
+ * A command for users to add new (key, value) pairs into DBPROPERTIES
+ * If the database does not exist, an error message will be issued to indicate the database
+ * does not exist.
+ * The syntax of using this command in SQL is:
+ * {{{
+ * ALTER (DATABASE|SCHEMA) database_name SET DBPROPERTIES (property_name=property_value, ...)
+ * }}}
+ */
+case class AlterDatabaseProperties(
+ databaseName: String,
+ props: Map[String, String])
+ extends RunnableCommand {
+
+ override def run(sqlContext: SQLContext): Seq[Row] = {
+ val catalog = sqlContext.sessionState.catalog
+ val db: CatalogDatabase = catalog.getDatabaseMetadata(databaseName)
+ catalog.alterDatabase(db.copy(properties = db.properties ++ props))
+
+ Seq.empty[Row]
+ }
+
+ override val output: Seq[Attribute] = Seq.empty
+}
+
+/**
+ * A command for users to show the name of the database, its comment (if one has been set), and its
+ * root location on the filesystem. When extended is true, it also shows the database's properties.
+ * If the database does not exist, an error message will be issued to indicate the database
+ * does not exist.
+ * The syntax of using this command in SQL is
+ * {{{
+ * DESCRIBE DATABASE [EXTENDED] db_name
+ * }}}
+ */
+case class DescribeDatabase(
+ databaseName: String,
+ extended: Boolean)
+ extends RunnableCommand {
+
+ override def run(sqlContext: SQLContext): Seq[Row] = {
+ val dbMetadata: CatalogDatabase =
+ sqlContext.sessionState.catalog.getDatabaseMetadata(databaseName)
+ val result =
+ Row("Database Name", dbMetadata.name) ::
+ Row("Description", dbMetadata.description) ::
+ Row("Location", dbMetadata.locationUri) :: Nil
+
+ if (extended) {
+ val properties =
+ if (dbMetadata.properties.isEmpty) {
+ ""
+ } else {
+ dbMetadata.properties.toSeq.mkString("(", ", ", ")")
+ }
+ result :+ Row("Properties", properties)
+ } else {
+ result
+ }
+ }
+
+ override val output: Seq[Attribute] = {
+ AttributeReference("database_description_item", StringType, nullable = false)() ::
+ AttributeReference("database_description_value", StringType, nullable = false)() :: Nil
+ }
+}
+
+/**
+ * Drops a table/view from the metastore and removes it if it is cached.
+ *
+ * The syntax of this command is:
+ * {{{
+ * DROP TABLE [IF EXISTS] table_name;
+ * DROP VIEW [IF EXISTS] [db_name.]view_name;
+ * }}}
+ */
+case class DropTable(
tableName: TableIdentifier,
- properties: Map[String, String])(sql: String)
- extends NativeDDLCommand(sql) with Logging
+ ifExists: Boolean,
+ isView: Boolean) extends RunnableCommand {
-case class AlterTableUnsetProperties(
+ override def run(sqlContext: SQLContext): Seq[Row] = {
+ val catalog = sqlContext.sessionState.catalog
+ if (!catalog.tableExists(tableName)) {
+ if (!ifExists) {
+ val objectName = if (isView) "View" else "Table"
+ logError(s"$objectName '${tableName.quotedString}' does not exist")
+ }
+ } else {
+ // If the command DROP VIEW is to drop a table or DROP TABLE is to drop a view
+ // issue an exception.
+ catalog.getTableMetadataOption(tableName).map(_.tableType match {
+ case CatalogTableType.VIRTUAL_VIEW if !isView =>
+ throw new AnalysisException(
+ "Cannot drop a view with DROP TABLE. Please use DROP VIEW instead")
+ case o if o != CatalogTableType.VIRTUAL_VIEW && isView =>
+ throw new AnalysisException(
+ s"Cannot drop a table with DROP VIEW. Please use DROP TABLE instead")
+ case _ =>
+ })
+ try {
+ sqlContext.cacheManager.tryUncacheQuery(sqlContext.table(tableName.quotedString))
+ } catch {
+ case NonFatal(e) => log.warn(s"${e.getMessage}", e)
+ }
+ catalog.invalidateTable(tableName)
+ catalog.dropTable(tableName, ifExists)
+ }
+ Seq.empty[Row]
+ }
+}
+
+/**
+ * A command that sets table/view properties.
+ *
+ * The syntax of this command is:
+ * {{{
+ * ALTER TABLE table1 SET TBLPROPERTIES ('key1' = 'val1', 'key2' = 'val2', ...);
+ * ALTER VIEW view1 SET TBLPROPERTIES ('key1' = 'val1', 'key2' = 'val2', ...);
+ * }}}
+ */
+case class AlterTableSetProperties(
tableName: TableIdentifier,
properties: Map[String, String],
- ifExists: Boolean)(sql: String)
- extends NativeDDLCommand(sql) with Logging
+ isView: Boolean)
+ extends RunnableCommand {
-case class AlterTableSerDeProperties(
- tableName: TableIdentifier,
- serdeClassName: Option[String],
- serdeProperties: Option[Map[String, String]],
- partition: Option[Map[String, String]])(sql: String)
- extends NativeDDLCommand(sql) with Logging
+ override def run(sqlContext: SQLContext): Seq[Row] = {
+ val catalog = sqlContext.sessionState.catalog
+ DDLUtils.verifyAlterTableType(catalog, tableName, isView)
+ val table = catalog.getTableMetadata(tableName)
+ val newProperties = table.properties ++ properties
+ if (DDLUtils.isDatasourceTable(newProperties)) {
+ throw new AnalysisException(
+ "alter table properties is not supported for tables defined using the datasource API")
+ }
+ val newTable = table.copy(properties = newProperties)
+ catalog.alterTable(newTable)
+ Seq.empty[Row]
+ }
-case class AlterTableStorageProperties(
+}
+
+/**
+ * A command that unsets table/view properties.
+ *
+ * The syntax of this command is:
+ * {{{
+ * ALTER TABLE table1 UNSET TBLPROPERTIES [IF EXISTS] ('key1', 'key2', ...);
+ * ALTER VIEW view1 UNSET TBLPROPERTIES [IF EXISTS] ('key1', 'key2', ...);
+ * }}}
+ */
+case class AlterTableUnsetProperties(
tableName: TableIdentifier,
- buckets: BucketSpec)(sql: String)
- extends NativeDDLCommand(sql) with Logging
+ propKeys: Seq[String],
+ ifExists: Boolean,
+ isView: Boolean)
+ extends RunnableCommand {
-case class AlterTableNotClustered(
- tableName: TableIdentifier)(sql: String) extends NativeDDLCommand(sql) with Logging
+ override def run(sqlContext: SQLContext): Seq[Row] = {
+ val catalog = sqlContext.sessionState.catalog
+ DDLUtils.verifyAlterTableType(catalog, tableName, isView)
+ val table = catalog.getTableMetadata(tableName)
+ if (DDLUtils.isDatasourceTable(table)) {
+ throw new AnalysisException(
+ "alter table properties is not supported for datasource tables")
+ }
+ if (!ifExists) {
+ propKeys.foreach { k =>
+ if (!table.properties.contains(k)) {
+ throw new AnalysisException(
+ s"attempted to unset non-existent property '$k' in table '$tableName'")
+ }
+ }
+ }
+ val newProperties = table.properties.filter { case (k, _) => !propKeys.contains(k) }
+ val newTable = table.copy(properties = newProperties)
+ catalog.alterTable(newTable)
+ Seq.empty[Row]
+ }
-case class AlterTableNotSorted(
- tableName: TableIdentifier)(sql: String) extends NativeDDLCommand(sql) with Logging
+}
-case class AlterTableSkewed(
+/**
+ * A command that sets the serde class and/or serde properties of a table/view.
+ *
+ * The syntax of this command is:
+ * {{{
+ * ALTER TABLE table [PARTITION spec] SET SERDE serde_name [WITH SERDEPROPERTIES props];
+ * ALTER TABLE table [PARTITION spec] SET SERDEPROPERTIES serde_properties;
+ * }}}
+ */
+case class AlterTableSerDeProperties(
tableName: TableIdentifier,
- // e.g. (dt, country)
- skewedCols: Seq[String],
- // e.g. ('2008-08-08', 'us), ('2009-09-09', 'uk')
- skewedValues: Seq[Seq[String]],
- storedAsDirs: Boolean)(sql: String)
- extends NativeDDLCommand(sql) with Logging {
-
- require(skewedValues.forall(_.size == skewedCols.size),
- "number of columns in skewed values do not match number of skewed columns provided")
-}
+ serdeClassName: Option[String],
+ serdeProperties: Option[Map[String, String]],
+ partition: Option[Map[String, String]])
+ extends RunnableCommand {
-case class AlterTableNotSkewed(
- tableName: TableIdentifier)(sql: String) extends NativeDDLCommand(sql) with Logging
+ // should never happen if we parsed things correctly
+ require(serdeClassName.isDefined || serdeProperties.isDefined,
+ "alter table attempted to set neither serde class name nor serde properties")
-case class AlterTableNotStoredAsDirs(
- tableName: TableIdentifier)(sql: String) extends NativeDDLCommand(sql) with Logging
+ override def run(sqlContext: SQLContext): Seq[Row] = {
+ val catalog = sqlContext.sessionState.catalog
+ val table = catalog.getTableMetadata(tableName)
+ // Do not support setting serde for datasource tables
+ if (serdeClassName.isDefined && DDLUtils.isDatasourceTable(table)) {
+ throw new AnalysisException(
+ "alter table serde is not supported for datasource tables")
+ }
+ val newTable = table.withNewStorage(
+ serde = serdeClassName.orElse(table.storage.serde),
+ serdeProperties = table.storage.serdeProperties ++ serdeProperties.getOrElse(Map()))
+ catalog.alterTable(newTable)
+ Seq.empty[Row]
+ }
-case class AlterTableSkewedLocation(
- tableName: TableIdentifier,
- skewedMap: Map[String, String])(sql: String)
- extends NativeDDLCommand(sql) with Logging
+}
+/**
+ * Add Partition in ALTER TABLE: add the table partitions.
+ *
+ * 'partitionSpecsAndLocs': the syntax of ALTER VIEW is identical to ALTER TABLE,
+ * EXCEPT that it is ILLEGAL to specify a LOCATION clause.
+ * An error message will be issued if the partition exists, unless 'ifNotExists' is true.
+ *
+ * The syntax of this command is:
+ * {{{
+ * ALTER TABLE table ADD [IF NOT EXISTS] PARTITION spec [LOCATION 'loc1']
+ * }}}
+ */
case class AlterTableAddPartition(
tableName: TableIdentifier,
partitionSpecsAndLocs: Seq[(TablePartitionSpec, Option[String])],
- ifNotExists: Boolean)(sql: String)
- extends NativeDDLCommand(sql) with Logging
+ ifNotExists: Boolean)
+ extends RunnableCommand {
+ override def run(sqlContext: SQLContext): Seq[Row] = {
+ val catalog = sqlContext.sessionState.catalog
+ val table = catalog.getTableMetadata(tableName)
+ if (DDLUtils.isDatasourceTable(table)) {
+ throw new AnalysisException(
+ "alter table add partition is not allowed for tables defined using the datasource API")
+ }
+ val parts = partitionSpecsAndLocs.map { case (spec, location) =>
+ // inherit table storage format (possibly except for location)
+ CatalogTablePartition(spec, table.storage.copy(locationUri = location))
+ }
+ catalog.createPartitions(tableName, parts, ignoreIfExists = ifNotExists)
+ Seq.empty[Row]
+ }
+
+}
+
+/**
+ * Alter a table partition's spec.
+ *
+ * The syntax of this command is:
+ * {{{
+ * ALTER TABLE table PARTITION spec1 RENAME TO PARTITION spec2;
+ * }}}
+ */
case class AlterTableRenamePartition(
tableName: TableIdentifier,
oldPartition: TablePartitionSpec,
- newPartition: TablePartitionSpec)(sql: String)
- extends NativeDDLCommand(sql) with Logging
+ newPartition: TablePartitionSpec)
+ extends RunnableCommand {
-case class AlterTableExchangePartition(
- fromTableName: TableIdentifier,
- toTableName: TableIdentifier,
- spec: TablePartitionSpec)(sql: String)
- extends NativeDDLCommand(sql) with Logging
+ override def run(sqlContext: SQLContext): Seq[Row] = {
+ sqlContext.sessionState.catalog.renamePartitions(
+ tableName, Seq(oldPartition), Seq(newPartition))
+ Seq.empty[Row]
+ }
+}
+
+/**
+ * Drop Partition in ALTER TABLE: to drop a particular partition for a table.
+ *
+ * This removes the data and metadata for this partition.
+ * The data is actually moved to the .Trash/Current directory if Trash is configured,
+ * unless 'purge' is true, but the metadata is completely lost.
+ * An error message will be issued if the partition does not exist, unless 'ifExists' is true.
+ * Note: purge is always false when the target is a view.
+ *
+ * The syntax of this command is:
+ * {{{
+ * ALTER TABLE table DROP [IF EXISTS] PARTITION spec1[, PARTITION spec2, ...] [PURGE];
+ * }}}
+ */
case class AlterTableDropPartition(
tableName: TableIdentifier,
specs: Seq[TablePartitionSpec],
- ifExists: Boolean,
- purge: Boolean)(sql: String)
- extends NativeDDLCommand(sql) with Logging
+ ifExists: Boolean)
+ extends RunnableCommand {
-case class AlterTableArchivePartition(
- tableName: TableIdentifier,
- spec: TablePartitionSpec)(sql: String)
- extends NativeDDLCommand(sql) with Logging
+ override def run(sqlContext: SQLContext): Seq[Row] = {
+ val catalog = sqlContext.sessionState.catalog
+ val table = catalog.getTableMetadata(tableName)
+ if (DDLUtils.isDatasourceTable(table)) {
+ throw new AnalysisException(
+ "alter table drop partition is not allowed for tables defined using the datasource API")
+ }
+ catalog.dropPartitions(tableName, specs, ignoreIfNotExists = ifExists)
+ Seq.empty[Row]
+ }
-case class AlterTableUnarchivePartition(
- tableName: TableIdentifier,
- spec: TablePartitionSpec)(sql: String)
- extends NativeDDLCommand(sql) with Logging
+}
case class AlterTableSetFileFormat(
tableName: TableIdentifier,
@@ -177,27 +429,55 @@ case class AlterTableSetFileFormat(
genericFormat: Option[String])(sql: String)
extends NativeDDLCommand(sql) with Logging
+/**
+ * A command that sets the location of a table or a partition.
+ *
+ * For normal tables, this just sets the location URI in the table/partition's storage format.
+ * For datasource tables, this sets a "path" parameter in the table/partition's serde properties.
+ *
+ * The syntax of this command is:
+ * {{{
+ * ALTER TABLE table_name [PARTITION partition_spec] SET LOCATION "loc";
+ * }}}
+ */
case class AlterTableSetLocation(
tableName: TableIdentifier,
partitionSpec: Option[TablePartitionSpec],
- location: String)(sql: String)
- extends NativeDDLCommand(sql) with Logging
-
-case class AlterTableTouch(
- tableName: TableIdentifier,
- partitionSpec: Option[TablePartitionSpec])(sql: String)
- extends NativeDDLCommand(sql) with Logging
+ location: String)
+ extends RunnableCommand {
-case class AlterTableCompact(
- tableName: TableIdentifier,
- partitionSpec: Option[TablePartitionSpec],
- compactType: String)(sql: String)
- extends NativeDDLCommand(sql) with Logging
+ override def run(sqlContext: SQLContext): Seq[Row] = {
+ val catalog = sqlContext.sessionState.catalog
+ val table = catalog.getTableMetadata(tableName)
+ partitionSpec match {
+ case Some(spec) =>
+ // Partition spec is specified, so we set the location only for this partition
+ val part = catalog.getPartition(tableName, spec)
+ val newPart =
+ if (DDLUtils.isDatasourceTable(table)) {
+ throw new AnalysisException(
+ "alter table set location for partition is not allowed for tables defined " +
+ "using the datasource API")
+ } else {
+ part.copy(storage = part.storage.copy(locationUri = Some(location)))
+ }
+ catalog.alterPartitions(tableName, Seq(newPart))
+ case None =>
+ // No partition spec is specified, so we set the location for the table itself
+ val newTable =
+ if (DDLUtils.isDatasourceTable(table)) {
+ table.withNewStorage(
+ locationUri = Some(location),
+ serdeProperties = table.storage.serdeProperties ++ Map("path" -> location))
+ } else {
+ table.withNewStorage(locationUri = Some(location))
+ }
+ catalog.alterTable(newTable)
+ }
+ Seq.empty[Row]
+ }
-case class AlterTableMerge(
- tableName: TableIdentifier,
- partitionSpec: Option[TablePartitionSpec])(sql: String)
- extends NativeDDLCommand(sql) with Logging
+}
case class AlterTableChangeCol(
tableName: TableIdentifier,
@@ -226,3 +506,35 @@ case class AlterTableReplaceCol(
restrict: Boolean,
cascade: Boolean)(sql: String)
extends NativeDDLCommand(sql) with Logging
+
+
+private object DDLUtils {
+
+ def isDatasourceTable(props: Map[String, String]): Boolean = {
+ props.contains("spark.sql.sources.provider")
+ }
+
+ def isDatasourceTable(table: CatalogTable): Boolean = {
+ isDatasourceTable(table.properties)
+ }
+
+ /**
+ * If the command ALTER VIEW is to alter a table or ALTER TABLE is to alter a view,
+ * issue an exception [[AnalysisException]].
+ */
+ def verifyAlterTableType(
+ catalog: SessionCatalog,
+ tableIdentifier: TableIdentifier,
+ isView: Boolean): Unit = {
+ catalog.getTableMetadataOption(tableIdentifier).map(_.tableType match {
+ case CatalogTableType.VIRTUAL_VIEW if !isView =>
+ throw new AnalysisException(
+ "Cannot alter a view with ALTER TABLE. Please use ALTER VIEW instead")
+ case o if o != CatalogTableType.VIRTUAL_VIEW && isView =>
+ throw new AnalysisException(
+ s"Cannot alter a table with ALTER VIEW. Please use ALTER TABLE instead")
+ case _ =>
+ })
+ }
+}
+
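Note: the database and table DDL that now executes through these RunnableCommands corresponds to SQL such as the following sketch (names are placeholders; a Hive-backed catalog with the referenced objects is assumed, e.g. in spark-shell where sqlContext is predefined):

    // Illustrative only.
    sqlContext.sql("CREATE DATABASE IF NOT EXISTS reporting")
    sqlContext.sql("ALTER DATABASE reporting SET DBPROPERTIES ('owner' = 'etl')")
    sqlContext.sql("DESCRIBE DATABASE EXTENDED reporting").show()
    sqlContext.sql("ALTER TABLE reporting.events SET TBLPROPERTIES ('comment' = 'raw clickstream')")
    sqlContext.sql("ALTER TABLE reporting.events ADD IF NOT EXISTS PARTITION (dt='2016-04-14')")
    sqlContext.sql("ALTER TABLE reporting.events PARTITION (dt='2016-04-14') SET LOCATION '/data/events/2016-04-14'")
    sqlContext.sql("ALTER TABLE reporting.events DROP IF EXISTS PARTITION (dt='2016-04-14')")
    sqlContext.sql("DROP DATABASE IF EXISTS reporting CASCADE")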
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/functions.scala
new file mode 100644
index 0000000000..c6e601799f
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/functions.scala
@@ -0,0 +1,105 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.command
+
+import org.apache.spark.sql.{AnalysisException, Row, SQLContext}
+import org.apache.spark.sql.catalyst.FunctionIdentifier
+import org.apache.spark.sql.catalyst.catalog.CatalogFunction
+import org.apache.spark.sql.catalyst.expressions.ExpressionInfo
+
+
+/**
+ * The DDL command that creates a function.
+ * To create a temporary function, the syntax of using this command in SQL is:
+ * {{{
+ * CREATE TEMPORARY FUNCTION functionName
+ * AS className [USING JAR\FILE 'uri' [, JAR|FILE 'uri']]
+ * }}}
+ *
+ * To create a permanent function, the syntax in SQL is:
+ * {{{
+ * CREATE FUNCTION [databaseName.]functionName
+ * AS className [USING JAR\FILE 'uri' [, JAR|FILE 'uri']]
+ * }}}
+ */
+// TODO: Use Seq[FunctionResource] instead of Seq[(String, String)] for resources.
+case class CreateFunction(
+ databaseName: Option[String],
+ functionName: String,
+ className: String,
+ resources: Seq[(String, String)],
+ isTemp: Boolean)
+ extends RunnableCommand {
+
+ override def run(sqlContext: SQLContext): Seq[Row] = {
+ val catalog = sqlContext.sessionState.catalog
+ if (isTemp) {
+ if (databaseName.isDefined) {
+ throw new AnalysisException(
+ s"It is not allowed to provide database name when defining a temporary function. " +
+ s"However, database name ${databaseName.get} is provided.")
+ }
+ // We first load resources and then put the builder in the function registry.
+ // Please note that it is allowed to overwrite an existing temp function.
+ catalog.loadFunctionResources(resources)
+ val info = new ExpressionInfo(className, functionName)
+ val builder = catalog.makeFunctionBuilder(functionName, className)
+ catalog.createTempFunction(functionName, info, builder, ignoreIfExists = false)
+ } else {
+      // For a permanent function, we store the metadata in the underlying external catalog.
+ // This function will be loaded into the FunctionRegistry when a query uses it.
+ // We do not load it into FunctionRegistry right now.
+ // TODO: should we also parse "IF NOT EXISTS"?
+ catalog.createFunction(
+ CatalogFunction(FunctionIdentifier(functionName, databaseName), className, resources),
+ ignoreIfExists = false)
+ }
+ Seq.empty[Row]
+ }
+}
+
+/**
+ * The DDL command that drops a function.
+ * ifExists: returns an error if the function doesn't exist, unless this is true.
+ * isTemp: indicates if it is a temporary function.
+ */
+case class DropFunction(
+ databaseName: Option[String],
+ functionName: String,
+ ifExists: Boolean,
+ isTemp: Boolean)
+ extends RunnableCommand {
+
+ override def run(sqlContext: SQLContext): Seq[Row] = {
+ val catalog = sqlContext.sessionState.catalog
+ if (isTemp) {
+ if (databaseName.isDefined) {
+ throw new AnalysisException(
+ s"It is not allowed to provide database name when dropping a temporary function. " +
+ s"However, database name ${databaseName.get} is provided.")
+ }
+ catalog.dropTempFunction(functionName, ifExists)
+ } else {
+ // We are dropping a permanent function.
+ catalog.dropFunction(
+ FunctionIdentifier(functionName, databaseName),
+ ignoreIfNotExists = ifExists)
+ }
+ Seq.empty[Row]
+ }
+}
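Note: the new CreateFunction/DropFunction commands map onto SQL of the following shape (the UDF class, jar URI, and database names are placeholders; the syntax mirrors the Scaladoc above):

    // Illustrative only; sqlContext is an existing SQLContext.
    sqlContext.sql(
      "CREATE TEMPORARY FUNCTION to_upper AS 'com.example.udf.ToUpper' USING JAR 'hdfs:///udfs/udfs.jar'")
    sqlContext.sql("CREATE FUNCTION mydb.to_upper AS 'com.example.udf.ToUpper'")
    sqlContext.sql("DROP TEMPORARY FUNCTION IF EXISTS to_upper")
    sqlContext.sql("DROP FUNCTION IF EXISTS mydb.to_upper")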
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala
new file mode 100644
index 0000000000..0b41985174
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala
@@ -0,0 +1,118 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.command
+
+import org.apache.spark.sql.{AnalysisException, Row, SQLContext}
+import org.apache.spark.sql.catalyst.TableIdentifier
+import org.apache.spark.sql.catalyst.catalog.{CatalogTable, CatalogTableType}
+
+/**
+ * A command to create a table with the same definition of the given existing table.
+ *
+ * The syntax of using this command in SQL is:
+ * {{{
+ * CREATE TABLE [IF NOT EXISTS] [db_name.]table_name
+ * LIKE [other_db_name.]existing_table_name
+ * }}}
+ */
+case class CreateTableLike(
+ targetTable: TableIdentifier,
+ sourceTable: TableIdentifier,
+ ifNotExists: Boolean) extends RunnableCommand {
+
+ override def run(sqlContext: SQLContext): Seq[Row] = {
+ val catalog = sqlContext.sessionState.catalog
+ if (!catalog.tableExists(sourceTable)) {
+ throw new AnalysisException(
+ s"Source table in CREATE TABLE LIKE does not exist: '$sourceTable'")
+ }
+ if (catalog.isTemporaryTable(sourceTable)) {
+ throw new AnalysisException(
+ s"Source table in CREATE TABLE LIKE cannot be temporary: '$sourceTable'")
+ }
+
+ val tableToCreate = catalog.getTableMetadata(sourceTable).copy(
+ identifier = targetTable,
+ tableType = CatalogTableType.MANAGED_TABLE,
+ createTime = System.currentTimeMillis,
+ lastAccessTime = -1).withNewStorage(locationUri = None)
+
+ catalog.createTable(tableToCreate, ifNotExists)
+ Seq.empty[Row]
+ }
+}
+
+
+// TODO: move the rest of the table commands from ddl.scala to this file
+
+/**
+ * A command to create a table.
+ *
+ * Note: This is currently used only for creating Hive tables.
+ * This is not intended for temporary tables.
+ *
+ * The syntax of using this command in SQL is:
+ * {{{
+ * CREATE [EXTERNAL] TABLE [IF NOT EXISTS] [db_name.]table_name
+ * [(col1 data_type [COMMENT col_comment], ...)]
+ * [COMMENT table_comment]
+ * [PARTITIONED BY (col3 data_type [COMMENT col_comment], ...)]
+ * [CLUSTERED BY (col1, ...) [SORTED BY (col1 [ASC|DESC], ...)] INTO num_buckets BUCKETS]
+ * [SKEWED BY (col1, col2, ...) ON ((col_value, col_value, ...), ...)
+ * [STORED AS DIRECTORIES]
+ * [ROW FORMAT row_format]
+ * [STORED AS file_format | STORED BY storage_handler_class [WITH SERDEPROPERTIES (...)]]
+ * [LOCATION path]
+ * [TBLPROPERTIES (property_name=property_value, ...)]
+ * [AS select_statement];
+ * }}}
+ */
+case class CreateTable(table: CatalogTable, ifNotExists: Boolean) extends RunnableCommand {
+
+ override def run(sqlContext: SQLContext): Seq[Row] = {
+ sqlContext.sessionState.catalog.createTable(table, ifNotExists)
+ Seq.empty[Row]
+ }
+
+}
+
+
+/**
+ * A command that renames a table/view.
+ *
+ * The syntax of this command is:
+ * {{{
+ * ALTER TABLE table1 RENAME TO table2;
+ * ALTER VIEW view1 RENAME TO view2;
+ * }}}
+ */
+case class AlterTableRename(
+ oldName: TableIdentifier,
+ newName: TableIdentifier,
+ isView: Boolean)
+ extends RunnableCommand {
+
+ override def run(sqlContext: SQLContext): Seq[Row] = {
+ val catalog = sqlContext.sessionState.catalog
+ DDLUtils.verifyAlterTableType(catalog, oldName, isView)
+ catalog.invalidateTable(oldName)
+ catalog.renameTable(oldName, newName)
+ Seq.empty[Row]
+ }
+
+}
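
A brief usage sketch for the commands defined in this new file (the table and database names and
the sqlContext value are illustrative assumptions):

  sqlContext.sql("CREATE TABLE IF NOT EXISTS db1.copy_of_src LIKE db1.src") // CreateTableLike
  sqlContext.sql("ALTER TABLE db1.copy_of_src RENAME TO db1.dst")           // AlterTableRename
  // CreateTableLike rejects a missing or temporary source table with an AnalysisException and
  // creates the copy as a managed table with no explicit location, as implemented above.
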
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
index c66921f485..10fde152ab 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
@@ -123,36 +123,58 @@ case class DataSource(
}
}
- /** Returns a source that can be used to continually read data. */
- def createSource(): Source = {
+ private def inferFileFormatSchema(format: FileFormat): StructType = {
+ val caseInsensitiveOptions = new CaseInsensitiveMap(options)
+ val allPaths = caseInsensitiveOptions.get("path")
+ val globbedPaths = allPaths.toSeq.flatMap { path =>
+ val hdfsPath = new Path(path)
+ val fs = hdfsPath.getFileSystem(sqlContext.sparkContext.hadoopConfiguration)
+ val qualified = hdfsPath.makeQualified(fs.getUri, fs.getWorkingDirectory)
+ SparkHadoopUtil.get.globPathIfNecessary(qualified)
+ }.toArray
+
+ val fileCatalog: FileCatalog = new HDFSFileCatalog(sqlContext, options, globbedPaths, None)
+ userSpecifiedSchema.orElse {
+ format.inferSchema(
+ sqlContext,
+ caseInsensitiveOptions,
+ fileCatalog.allFiles())
+ }.getOrElse {
+ throw new AnalysisException("Unable to infer schema. It must be specified manually.")
+ }
+ }
+
+ /** Returns the name and schema of the source that can be used to continually read data. */
+ def sourceSchema(): (String, StructType) = {
providingClass.newInstance() match {
case s: StreamSourceProvider =>
- s.createSource(sqlContext, userSpecifiedSchema, className, options)
+ s.sourceSchema(sqlContext, userSpecifiedSchema, className, options)
case format: FileFormat =>
val caseInsensitiveOptions = new CaseInsensitiveMap(options)
val path = caseInsensitiveOptions.getOrElse("path", {
throw new IllegalArgumentException("'path' is not specified")
})
- val metadataPath = caseInsensitiveOptions.getOrElse("metadataPath", s"$path/_metadata")
+ (s"FileSource[$path]", inferFileFormatSchema(format))
+ case _ =>
+ throw new UnsupportedOperationException(
+ s"Data source $className does not support streamed reading")
+ }
+ }
- val allPaths = caseInsensitiveOptions.get("path")
- val globbedPaths = allPaths.toSeq.flatMap { path =>
- val hdfsPath = new Path(path)
- val fs = hdfsPath.getFileSystem(sqlContext.sparkContext.hadoopConfiguration)
- val qualified = hdfsPath.makeQualified(fs.getUri, fs.getWorkingDirectory)
- SparkHadoopUtil.get.globPathIfNecessary(qualified)
- }.toArray
+ /** Returns a source that can be used to continually read data. */
+ def createSource(metadataPath: String): Source = {
+ providingClass.newInstance() match {
+ case s: StreamSourceProvider =>
+ s.createSource(sqlContext, metadataPath, userSpecifiedSchema, className, options)
- val fileCatalog: FileCatalog = new HDFSFileCatalog(sqlContext, options, globbedPaths, None)
- val dataSchema = userSpecifiedSchema.orElse {
- format.inferSchema(
- sqlContext,
- caseInsensitiveOptions,
- fileCatalog.allFiles())
- }.getOrElse {
- throw new AnalysisException("Unable to infer schema. It must be specified manually.")
- }
+ case format: FileFormat =>
+ val caseInsensitiveOptions = new CaseInsensitiveMap(options)
+ val path = caseInsensitiveOptions.getOrElse("path", {
+ throw new IllegalArgumentException("'path' is not specified")
+ })
+
+ val dataSchema = inferFileFormatSchema(format)
def dataFrameBuilder(files: Array[String]): DataFrame = {
Dataset.ofRows(
@@ -299,6 +321,9 @@ case class DataSource(
"It must be specified manually")
}
+ val enrichedOptions =
+ format.prepareRead(sqlContext, caseInsensitiveOptions, fileCatalog.allFiles())
+
HadoopFsRelation(
sqlContext,
fileCatalog,
@@ -306,7 +331,7 @@ case class DataSource(
dataSchema = dataSchema.asNullable,
bucketSpec = bucketSpec,
format,
- options)
+ enrichedOptions)
case _ =>
throw new AnalysisException(
@@ -345,16 +370,6 @@ case class DataSource(
PartitioningUtils.validatePartitionColumnDataTypes(
data.schema, partitionColumns, caseSensitive)
- val equality =
- if (sqlContext.conf.caseSensitiveAnalysis) {
- org.apache.spark.sql.catalyst.analysis.caseSensitiveResolution
- } else {
- org.apache.spark.sql.catalyst.analysis.caseInsensitiveResolution
- }
-
- val dataSchema = StructType(
- data.schema.filterNot(f => partitionColumns.exists(equality(_, f.name))))
-
// If we are appending to a table that already exists, make sure the partitioning matches
// up. If we fail to load the table for whatever reason, ignore the check.
if (mode == SaveMode.Append) {
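
To make the createSource split concrete, a hedged sketch of how a streaming caller might now
drive a DataSource (the dataSource value and the metadata path are assumptions):

  val (sourceName, schema) = dataSource.sourceSchema()              // resolves the name and schema only
  val source = dataSource.createSource("/checkpoints/q1/sources/0") // caller now supplies metadataPath
  // For FileFormat sources the schema comes from inferFileFormatSchema, which globs the
  // configured path and raises an AnalysisException when nothing can be inferred.
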
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
index 52c8f3ef0b..ac3c52e901 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
@@ -19,10 +19,8 @@ package org.apache.spark.sql.execution.datasources
import scala.collection.mutable.ArrayBuffer
-import org.apache.spark.TaskContext
-import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.internal.Logging
-import org.apache.spark.rdd.{MapPartitionsRDD, RDD, UnionRDD}
+import org.apache.spark.rdd.RDD
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
import org.apache.spark.sql.catalyst.CatalystTypeConverters.convertToScala
@@ -35,14 +33,10 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.DataSourceScan.{INPUT_PATHS, PUSHED_FILTERS}
-import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.command.ExecutedCommand
-import org.apache.spark.sql.execution.vectorized.{ColumnarBatch, ColumnVectorUtils}
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
-import org.apache.spark.util.{SerializableConfiguration, Utils}
-import org.apache.spark.util.collection.BitSet
/**
* Replaces generic operations with specific variants that are designed to work with Spark
@@ -110,135 +104,8 @@ private[sql] object DataSourceStrategy extends Strategy with Logging {
filters,
(a, _) => toCatalystRDD(l, a, t.buildScan(a.map(_.name).toArray))) :: Nil
- // Scanning partitioned HadoopFsRelation
- case PhysicalOperation(projects, filters, l @ LogicalRelation(t: HadoopFsRelation, _, _))
- if t.partitionSchema.nonEmpty =>
- // We divide the filter expressions into 3 parts
- val partitionColumns = AttributeSet(
- t.partitionSchema.map(c => l.output.find(_.name == c.name).get))
-
- // Only pruning the partition keys
- val partitionFilters = filters.filter(_.references.subsetOf(partitionColumns))
-
- // Only pushes down predicates that do not reference partition keys.
- val pushedFilters = filters.filter(_.references.intersect(partitionColumns).isEmpty)
-
- // Predicates with both partition keys and attributes
- val partitionAndNormalColumnFilters =
- filters.toSet -- partitionFilters.toSet -- pushedFilters.toSet
-
- val selectedPartitions = t.location.listFiles(partitionFilters)
-
- logInfo {
- val total = t.partitionSpec.partitions.length
- val selected = selectedPartitions.length
- val percentPruned = (1 - selected.toDouble / total.toDouble) * 100
- s"Selected $selected partitions out of $total, pruned $percentPruned% partitions."
- }
-
- // need to add projections from "partitionAndNormalColumnAttrs" in if it is not empty
- val partitionAndNormalColumnAttrs = AttributeSet(partitionAndNormalColumnFilters)
- val partitionAndNormalColumnProjs = if (partitionAndNormalColumnAttrs.isEmpty) {
- projects
- } else {
- (partitionAndNormalColumnAttrs ++ projects).toSeq
- }
-
- // Prune the buckets based on the pushed filters that do not contain partitioning key
- // since the bucketing key is not allowed to use the columns in partitioning key
- val bucketSet = getBuckets(pushedFilters, t.bucketSpec)
- val scan = buildPartitionedTableScan(
- l,
- partitionAndNormalColumnProjs,
- pushedFilters,
- bucketSet,
- t.partitionSpec.partitionColumns,
- selectedPartitions,
- t.options)
-
- // Add a Projection to guarantee the original projection:
- // this is because "partitionAndNormalColumnAttrs" may be different
- // from the original "projects", in elements or their ordering
-
- partitionAndNormalColumnFilters.reduceLeftOption(expressions.And).map(cf =>
- if (projects.isEmpty || projects == partitionAndNormalColumnProjs) {
- // if the original projection is empty, no need for the additional Project either
- execution.Filter(cf, scan)
- } else {
- execution.Project(projects, execution.Filter(cf, scan))
- }
- ).getOrElse(scan) :: Nil
-
- // TODO: The code for planning bucketed/unbucketed/partitioned/unpartitioned tables contains
- // a lot of duplication and produces overly complicated RDDs.
-
- // Scanning non-partitioned HadoopFsRelation
- case PhysicalOperation(projects, filters, l @ LogicalRelation(t: HadoopFsRelation, _, _)) =>
- // See buildPartitionedTableScan for the reason that we need to create a shard
- // broadcast HadoopConf.
- val sharedHadoopConf = SparkHadoopUtil.get.conf
- val confBroadcast =
- t.sqlContext.sparkContext.broadcast(new SerializableConfiguration(sharedHadoopConf))
-
- t.bucketSpec match {
- case Some(spec) if t.sqlContext.conf.bucketingEnabled =>
- val scanBuilder: (Seq[Attribute], Array[Filter]) => RDD[InternalRow] = {
- (requiredColumns: Seq[Attribute], filters: Array[Filter]) => {
- val bucketed =
- t.location
- .allFiles()
- .filterNot(_.getPath.getName startsWith "_")
- .groupBy { f =>
- BucketingUtils
- .getBucketId(f.getPath.getName)
- .getOrElse(sys.error(s"Invalid bucket file ${f.getPath}"))
- }
-
- val bucketedDataMap = bucketed.mapValues { bucketFiles =>
- t.fileFormat.buildInternalScan(
- t.sqlContext,
- t.dataSchema,
- requiredColumns.map(_.name).toArray,
- filters,
- None,
- bucketFiles,
- confBroadcast,
- t.options).coalesce(1)
- }
-
- val bucketedRDD = new UnionRDD(t.sqlContext.sparkContext,
- (0 until spec.numBuckets).map { bucketId =>
- bucketedDataMap.getOrElse(bucketId, t.sqlContext.emptyResult: RDD[InternalRow])
- })
- bucketedRDD
- }
- }
-
- pruneFilterProject(
- l,
- projects,
- filters,
- scanBuilder) :: Nil
-
- case _ =>
- pruneFilterProject(
- l,
- projects,
- filters,
- (a, f) =>
- t.fileFormat.buildInternalScan(
- t.sqlContext,
- t.dataSchema,
- a.map(_.name).toArray,
- f,
- None,
- t.location.allFiles(),
- confBroadcast,
- t.options)) :: Nil
- }
-
case l @ LogicalRelation(baseRelation: TableScan, _, _) =>
- execution.DataSourceScan(
+ execution.DataSourceScan.create(
l.output, toCatalystRDD(l, baseRelation.buildScan()), baseRelation) :: Nil
case i @ logical.InsertIntoTable(l @ LogicalRelation(t: InsertableRelation, _, _),
@@ -248,218 +115,6 @@ private[sql] object DataSourceStrategy extends Strategy with Logging {
case _ => Nil
}
- private def buildPartitionedTableScan(
- logicalRelation: LogicalRelation,
- projections: Seq[NamedExpression],
- filters: Seq[Expression],
- buckets: Option[BitSet],
- partitionColumns: StructType,
- partitions: Seq[Partition],
- options: Map[String, String]): SparkPlan = {
- val relation = logicalRelation.relation.asInstanceOf[HadoopFsRelation]
-
- // Because we are creating one RDD per partition, we need to have a shared HadoopConf.
- // Otherwise, the cost of broadcasting HadoopConf in every RDD will be high.
- val sharedHadoopConf = SparkHadoopUtil.get.conf
- val confBroadcast =
- relation.sqlContext.sparkContext.broadcast(new SerializableConfiguration(sharedHadoopConf))
- val partitionColumnNames = partitionColumns.fieldNames.toSet
-
- // Now, we create a scan builder, which will be used by pruneFilterProject. This scan builder
- // will union all partitions and attach partition values if needed.
- val scanBuilder: (Seq[Attribute], Array[Filter]) => RDD[InternalRow] = {
- (requiredColumns: Seq[Attribute], filters: Array[Filter]) => {
-
- relation.bucketSpec match {
- case Some(spec) if relation.sqlContext.conf.bucketingEnabled =>
- val requiredDataColumns =
- requiredColumns.filterNot(c => partitionColumnNames.contains(c.name))
-
- // Builds RDD[Row]s for each selected partition.
- val perPartitionRows: Seq[(Int, RDD[InternalRow])] = partitions.flatMap {
- case Partition(partitionValues, files) =>
- val bucketed = files.groupBy { f =>
- BucketingUtils
- .getBucketId(f.getPath.getName)
- .getOrElse(sys.error(s"Invalid bucket file ${f.getPath}"))
- }
-
- bucketed.map { bucketFiles =>
- // Don't scan any partition columns to save I/O. Here we are being optimistic and
- // assuming partition columns data stored in data files are always consistent with
- // those partition values encoded in partition directory paths.
- val dataRows = relation.fileFormat.buildInternalScan(
- relation.sqlContext,
- relation.dataSchema,
- requiredDataColumns.map(_.name).toArray,
- filters,
- buckets,
- bucketFiles._2,
- confBroadcast,
- options)
-
- // Merges data values with partition values.
- bucketFiles._1 -> mergeWithPartitionValues(
- requiredColumns,
- requiredDataColumns,
- partitionColumns,
- partitionValues,
- dataRows)
- }
- }
-
- val bucketedDataMap: Map[Int, Seq[RDD[InternalRow]]] =
- perPartitionRows.groupBy(_._1).mapValues(_.map(_._2))
-
- val bucketed = new UnionRDD(relation.sqlContext.sparkContext,
- (0 until spec.numBuckets).map { bucketId =>
- bucketedDataMap.get(bucketId).map(i => i.reduce(_ ++ _).coalesce(1)).getOrElse {
- relation.sqlContext.emptyResult: RDD[InternalRow]
- }
- })
- bucketed
-
- case _ =>
- val requiredDataColumns =
- requiredColumns.filterNot(c => partitionColumnNames.contains(c.name))
-
- // Builds RDD[Row]s for each selected partition.
- val perPartitionRows = partitions.map {
- case Partition(partitionValues, files) =>
- val dataRows = relation.fileFormat.buildInternalScan(
- relation.sqlContext,
- relation.dataSchema,
- requiredDataColumns.map(_.name).toArray,
- filters,
- buckets,
- files,
- confBroadcast,
- options)
-
- // Merges data values with partition values.
- mergeWithPartitionValues(
- requiredColumns,
- requiredDataColumns,
- partitionColumns,
- partitionValues,
- dataRows)
- }
- new UnionRDD(relation.sqlContext.sparkContext, perPartitionRows)
- }
- }
- }
-
- // Create the scan operator. If needed, add Filter and/or Project on top of the scan.
- // The added Filter/Project is on top of the unioned RDD. We do not want to create
- // one Filter/Project for every partition.
- val sparkPlan = pruneFilterProject(
- logicalRelation,
- projections,
- filters,
- scanBuilder)
-
- sparkPlan
- }
-
- /**
- * Creates a ColumnarBatch that contains the values for `requiredColumns`. These columns can
- * either come from `input` (columns scanned from the data source) or from the partitioning
- * values (data from `partitionValues`). This is done *once* per physical partition. When
- * the column is from `input`, it just references the same underlying column. When using
- * partition columns, the column is populated once.
- * TODO: there's probably a cleaner way to do this.
- */
- private def projectedColumnBatch(
- input: ColumnarBatch,
- requiredColumns: Seq[Attribute],
- dataColumns: Seq[Attribute],
- partitionColumnSchema: StructType,
- partitionValues: InternalRow) : ColumnarBatch = {
- val result = ColumnarBatch.allocate(StructType.fromAttributes(requiredColumns))
- var resultIdx = 0
- var inputIdx = 0
-
- while (resultIdx < requiredColumns.length) {
- val attr = requiredColumns(resultIdx)
- if (inputIdx < dataColumns.length && requiredColumns(resultIdx) == dataColumns(inputIdx)) {
- result.setColumn(resultIdx, input.column(inputIdx))
- inputIdx += 1
- } else {
- require(partitionColumnSchema.fields.count(_.name == attr.name) == 1)
- var partitionIdx = 0
- partitionColumnSchema.fields.foreach { f => {
- if (f.name.equals(attr.name)) {
- ColumnVectorUtils.populate(result.column(resultIdx), partitionValues, partitionIdx)
- }
- partitionIdx += 1
- }}
- }
- resultIdx += 1
- }
- result
- }
-
- private def mergeWithPartitionValues(
- requiredColumns: Seq[Attribute],
- dataColumns: Seq[Attribute],
- partitionColumnSchema: StructType,
- partitionValues: InternalRow,
- dataRows: RDD[InternalRow]): RDD[InternalRow] = {
- // If output columns contain any partition column(s), we need to merge scanned data
- // columns and requested partition columns to form the final result.
- if (requiredColumns != dataColumns) {
- // Builds `AttributeReference`s for all partition columns so that we can use them to project
- // required partition columns. Note that if a partition column appears in `requiredColumns`,
- // we should use the `AttributeReference` in `requiredColumns`.
- val partitionColumns = {
- val requiredColumnMap = requiredColumns.map(a => a.name -> a).toMap
- partitionColumnSchema.toAttributes.map { a =>
- requiredColumnMap.getOrElse(a.name, a)
- }
- }
-
- val mapPartitionsFunc = (_: TaskContext, _: Int, iterator: Iterator[Object]) => {
- // Note that we can't use an `UnsafeRowJoiner` to replace the following `JoinedRow` and
- // `UnsafeProjection`. Because the projection may also adjust column order.
- val mutableJoinedRow = new JoinedRow()
- val unsafePartitionValues = UnsafeProjection.create(partitionColumnSchema)(partitionValues)
- val unsafeProjection =
- UnsafeProjection.create(requiredColumns, dataColumns ++ partitionColumns)
-
- // If we are returning batches directly, we need to augment them with the partitioning
- // columns. We want to do this without a row by row operation.
- var columnBatch: ColumnarBatch = null
- var mergedBatch: ColumnarBatch = null
-
- iterator.map { input => {
- if (input.isInstanceOf[InternalRow]) {
- unsafeProjection(mutableJoinedRow(
- input.asInstanceOf[InternalRow], unsafePartitionValues))
- } else {
- require(input.isInstanceOf[ColumnarBatch])
- val inputBatch = input.asInstanceOf[ColumnarBatch]
- if (inputBatch != mergedBatch) {
- mergedBatch = inputBatch
- columnBatch = projectedColumnBatch(inputBatch, requiredColumns,
- dataColumns, partitionColumnSchema, partitionValues)
- }
- columnBatch.setNumRows(inputBatch.numRows())
- columnBatch
- }
- }}
- }
-
- // This is an internal RDD whose call site the user should not be concerned with
- // Since we create many of these (one per partition), the time spent on computing
- // the call site may add up.
- Utils.withDummyCallSite(dataRows.sparkContext) {
- new MapPartitionsRDD(dataRows, mapPartitionsFunc, preservesPartitioning = false)
- }.asInstanceOf[RDD[InternalRow]]
- } else {
- dataRows
- }
- }
-
// Get the bucket ID based on the bucketing values.
// Restriction: Bucket pruning works iff the bucketing column has one and only one column.
def getBucketId(bucketColumn: Attribute, numBuckets: Int, value: Any): Int = {
@@ -472,57 +127,6 @@ private[sql] object DataSourceStrategy extends Strategy with Logging {
bucketIdGeneration(mutableRow).getInt(0)
}
- // Get the bucket BitSet by reading the filters that only contains bucketing keys.
- // Note: When the returned BitSet is None, no pruning is possible.
- // Restriction: Bucket pruning works iff the bucketing column has one and only one column.
- private def getBuckets(
- filters: Seq[Expression],
- bucketSpec: Option[BucketSpec]): Option[BitSet] = {
-
- if (bucketSpec.isEmpty ||
- bucketSpec.get.numBuckets == 1 ||
- bucketSpec.get.bucketColumnNames.length != 1) {
- // None means all the buckets need to be scanned
- return None
- }
-
- // Just get the first because bucketing pruning only works when the column has one column
- val bucketColumnName = bucketSpec.get.bucketColumnNames.head
- val numBuckets = bucketSpec.get.numBuckets
- val matchedBuckets = new BitSet(numBuckets)
- matchedBuckets.clear()
-
- filters.foreach {
- case expressions.EqualTo(a: Attribute, Literal(v, _)) if a.name == bucketColumnName =>
- matchedBuckets.set(getBucketId(a, numBuckets, v))
- case expressions.EqualTo(Literal(v, _), a: Attribute) if a.name == bucketColumnName =>
- matchedBuckets.set(getBucketId(a, numBuckets, v))
- case expressions.EqualNullSafe(a: Attribute, Literal(v, _)) if a.name == bucketColumnName =>
- matchedBuckets.set(getBucketId(a, numBuckets, v))
- case expressions.EqualNullSafe(Literal(v, _), a: Attribute) if a.name == bucketColumnName =>
- matchedBuckets.set(getBucketId(a, numBuckets, v))
- // Because we only convert In to InSet in Optimizer when there are more than certain
- // items. So it is possible we still get an In expression here that needs to be pushed
- // down.
- case expressions.In(a: Attribute, list)
- if list.forall(_.isInstanceOf[Literal]) && a.name == bucketColumnName =>
- val hSet = list.map(e => e.eval(EmptyRow))
- hSet.foreach(e => matchedBuckets.set(getBucketId(a, numBuckets, e)))
- case expressions.IsNull(a: Attribute) if a.name == bucketColumnName =>
- matchedBuckets.set(getBucketId(a, numBuckets, null))
- case _ =>
- }
-
- logInfo {
- val selected = matchedBuckets.cardinality()
- val percentPruned = (1 - selected.toDouble / numBuckets.toDouble) * 100
- s"Selected $selected buckets out of $numBuckets, pruned $percentPruned% partitions."
- }
-
- // None means all the buckets need to be scanned
- if (matchedBuckets.cardinality() == 0) None else Some(matchedBuckets)
- }
-
// Based on Public API.
protected def pruneFilterProject(
relation: LogicalRelation,
@@ -610,7 +214,7 @@ private[sql] object DataSourceStrategy extends Strategy with Logging {
// Don't request columns that are only referenced by pushed filters.
.filterNot(handledSet.contains)
- val scan = execution.DataSourceScan(
+ val scan = execution.DataSourceScan.create(
projects.map(_.toAttribute),
scanBuilder(requestedColumns, candidatePredicates, pushedFilters),
relation.relation, metadata)
@@ -620,7 +224,7 @@ private[sql] object DataSourceStrategy extends Strategy with Logging {
val requestedColumns =
(projectSet ++ filterSet -- handledSet).map(relation.attributeMap).toSeq
- val scan = execution.DataSourceScan(
+ val scan = execution.DataSourceScan.create(
requestedColumns,
scanBuilder(requestedColumns, candidatePredicates, pushedFilters),
relation.relation, metadata)
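
A hedged sketch of the surviving getBucketId helper (the attribute idAttr, the bucket count and
the literal are assumptions; how callers use the result is left to the planner):

  val bucketId: Int = DataSourceStrategy.getBucketId(idAttr, numBuckets = 8, value = 5)
  // Maps the literal 5 to the bucket it would be written to under HashPartitioning, which is
  // what makes single-column bucket pruning possible; multi-column bucket specs are not pruned.
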
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala
index 988c785dbe..468e101fed 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala
@@ -18,7 +18,7 @@
package org.apache.spark.sql.execution.datasources
import org.apache.spark.{Partition, TaskContext}
-import org.apache.spark.rdd.{RDD, SqlNewHadoopRDDState}
+import org.apache.spark.rdd.{InputFileNameHolder, RDD}
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.catalyst.InternalRow
@@ -37,7 +37,6 @@ case class PartitionedFile(
}
}
-
/**
* A collection of files that should be read as a single task possibly from multiple partitioned
* directories.
@@ -50,7 +49,7 @@ class FileScanRDD(
@transient val sqlContext: SQLContext,
readFunction: (PartitionedFile) => Iterator[InternalRow],
@transient val filePartitions: Seq[FilePartition])
- extends RDD[InternalRow](sqlContext.sparkContext, Nil) {
+ extends RDD[InternalRow](sqlContext.sparkContext, Nil) {
override def compute(split: Partition, context: TaskContext): Iterator[InternalRow] = {
val iterator = new Iterator[Object] with AutoCloseable {
@@ -65,17 +64,17 @@ class FileScanRDD(
if (files.hasNext) {
val nextFile = files.next()
logInfo(s"Reading File $nextFile")
- SqlNewHadoopRDDState.setInputFileName(nextFile.filePath)
+ InputFileNameHolder.setInputFileName(nextFile.filePath)
currentIterator = readFunction(nextFile)
hasNext
} else {
- SqlNewHadoopRDDState.unsetInputFileName()
+ InputFileNameHolder.unsetInputFileName()
false
}
}
override def close() = {
- SqlNewHadoopRDDState.unsetInputFileName()
+ InputFileNameHolder.unsetInputFileName()
}
}
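
The rename above only relocates the thread-local holder that, as the surrounding code suggests,
backs the SQL input_file_name() function. A hedged sketch (the DataFrame df is an assumption):

  import org.apache.spark.sql.functions.input_file_name
  df.select(input_file_name()).distinct().show() // lists the files the rows were read from
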
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala
index 4b04fec57d..80a9156ddc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala
@@ -55,11 +55,7 @@ import org.apache.spark.sql.sources._
*/
private[sql] object FileSourceStrategy extends Strategy with Logging {
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
- case PhysicalOperation(projects, filters, l @ LogicalRelation(files: HadoopFsRelation, _, _))
- if (files.fileFormat.toString == "TestFileFormat" ||
- files.fileFormat.isInstanceOf[parquet.DefaultSource] ||
- files.fileFormat.toString == "ORC") &&
- files.sqlContext.conf.parquetFileScan =>
+ case PhysicalOperation(projects, filters, l @ LogicalRelation(files: HadoopFsRelation, _, _)) =>
// Filters on this relation fall into four categories based on where we can use them to avoid
// reading unneeded data:
// - partition keys only - used to prune directories to read
@@ -68,26 +64,28 @@ private[sql] object FileSourceStrategy extends Strategy with Logging {
// - filters that need to be evaluated again after the scan
val filterSet = ExpressionSet(filters)
+ // The attribute name in a predicate may differ from the name in the schema when the
+ // analysis is case insensitive, so we rewrite it here to match the schema's name and
+ // no longer need to worry about case sensitivity afterwards.
+ val normalizedFilters = filters.map { e =>
+ e transform {
+ case a: AttributeReference =>
+ a.withName(l.output.find(_.semanticEquals(a)).get.name)
+ }
+ }
+
val partitionColumns =
l.resolve(files.partitionSchema, files.sqlContext.sessionState.analyzer.resolver)
val partitionSet = AttributeSet(partitionColumns)
val partitionKeyFilters =
- ExpressionSet(filters.filter(_.references.subsetOf(partitionSet)))
+ ExpressionSet(normalizedFilters.filter(_.references.subsetOf(partitionSet)))
logInfo(s"Pruning directories with: ${partitionKeyFilters.mkString(",")}")
val dataColumns =
l.resolve(files.dataSchema, files.sqlContext.sessionState.analyzer.resolver)
- val bucketColumns =
- AttributeSet(
- files.bucketSpec
- .map(_.bucketColumnNames)
- .getOrElse(Nil)
- .map(l.resolveQuoted(_, files.sqlContext.conf.resolver)
- .getOrElse(sys.error(""))))
-
// Partition keys are not available in the statistics of the files.
- val dataFilters = filters.filter(_.references.intersect(partitionSet).isEmpty)
+ val dataFilters = normalizedFilters.filter(_.references.intersect(partitionSet).isEmpty)
// Predicates with both partition keys and attributes need to be evaluated after the scan.
val afterScanFilters = filterSet -- partitionKeyFilters
@@ -111,8 +109,9 @@ private[sql] object FileSourceStrategy extends Strategy with Logging {
val readFile = files.fileFormat.buildReader(
sqlContext = files.sqlContext,
+ dataSchema = files.dataSchema,
partitionSchema = files.partitionSchema,
- dataSchema = prunedDataSchema,
+ requiredSchema = prunedDataSchema,
filters = pushedDownFilters,
options = files.options)
@@ -134,11 +133,12 @@ private[sql] object FileSourceStrategy extends Strategy with Logging {
case _ =>
val maxSplitBytes = files.sqlContext.conf.filesMaxPartitionBytes
- logInfo(s"Planning scan with bin packing, max size: $maxSplitBytes bytes")
+ val openCostInBytes = files.sqlContext.conf.filesOpenCostInBytes
+ logInfo(s"Planning scan with bin packing, max size: $maxSplitBytes bytes, " +
+ s"open cost is considered as scanning $openCostInBytes bytes.")
val splitFiles = selectedPartitions.flatMap { partition =>
partition.files.flatMap { file =>
- assert(file.getLen != 0, file.toString)
(0L to file.getLen by maxSplitBytes).map { offset =>
val remaining = file.getLen - offset
val size = if (remaining > maxSplitBytes) maxSplitBytes else remaining
@@ -153,7 +153,7 @@ private[sql] object FileSourceStrategy extends Strategy with Logging {
/** Add the given file to the current partition. */
def addFile(file: PartitionedFile): Unit = {
- currentSize += file.length
+ currentSize += file.length + openCostInBytes
currentFiles.append(file)
}
@@ -175,17 +175,15 @@ private[sql] object FileSourceStrategy extends Strategy with Logging {
splitFiles.foreach { file =>
if (currentSize + file.length > maxSplitBytes) {
closePartition()
- addFile(file)
- } else {
- addFile(file)
}
+ addFile(file)
}
closePartition()
partitions
}
val scan =
- DataSourceScan(
+ DataSourceScan.create(
readDataColumns ++ partitionColumns,
new FileScanRDD(
files.sqlContext,
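
A small worked example of the open-cost accounting introduced above (all values are assumptions,
chosen only to show the arithmetic):

  val maxSplitBytes   = 128L * 1024 * 1024 // 128 MB target partition size
  val openCostInBytes =   4L * 1024 * 1024 // 4 MB charged per file opened
  val fileLen         =   1L * 1024 * 1024 // 1 MB input files
  // Each file is accounted as fileLen + openCostInBytes = 5 MB, so a partition closes after
  // roughly 128 / 5 ~ 26 files instead of 128, bounding the number of file opens per task.
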
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFileLinesReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFileLinesReader.scala
new file mode 100644
index 0000000000..18f9b55895
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFileLinesReader.scala
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources
+
+import java.net.URI
+
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.Path
+import org.apache.hadoop.io.Text
+import org.apache.hadoop.mapreduce._
+import org.apache.hadoop.mapreduce.lib.input.{FileSplit, LineRecordReader}
+import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl
+
+/**
+ * An adaptor from a [[PartitionedFile]] to an [[Iterator]] of [[Text]] containing all of the
+ * lines in that file.
+ */
+class HadoopFileLinesReader(file: PartitionedFile, conf: Configuration) extends Iterator[Text] {
+ private val iterator = {
+ val fileSplit = new FileSplit(
+ new Path(new URI(file.filePath)),
+ file.start,
+ file.length,
+ // TODO: Implement Locality
+ Array.empty)
+ val attemptId = new TaskAttemptID(new TaskID(new JobID(), TaskType.MAP, 0), 0)
+ val hadoopAttemptContext = new TaskAttemptContextImpl(conf, attemptId)
+ val reader = new LineRecordReader()
+ reader.initialize(fileSplit, hadoopAttemptContext)
+ new RecordReaderIterator(reader)
+ }
+
+ override def hasNext: Boolean = iterator.hasNext
+
+ override def next(): Text = iterator.next()
+}
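
A hedged usage sketch of the new reader (the PartitionedFile and Hadoop Configuration values are
assumptions supplied by the caller):

  val lines: Iterator[org.apache.hadoop.io.Text] = new HadoopFileLinesReader(partitionedFile, hadoopConf)
  val firstLine: Option[String] = if (lines.hasNext) Some(lines.next().toString) else None
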
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelation.scala
index e31380e17d..889c0204f8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelation.scala
@@ -32,7 +32,6 @@ import org.apache.spark.sql.execution.SQLExecution
import org.apache.spark.sql.execution.command.RunnableCommand
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.sources._
-import org.apache.spark.util.Utils
/**
* A command for writing data to a [[HadoopFsRelation]]. Supports both overwriting and appending.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SqlNewHadoopRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SqlNewHadoopRDD.scala
deleted file mode 100644
index f3514cd14c..0000000000
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SqlNewHadoopRDD.scala
+++ /dev/null
@@ -1,314 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.rdd
-
-import java.text.SimpleDateFormat
-import java.util.Date
-
-import scala.reflect.ClassTag
-
-import org.apache.hadoop.conf.{Configurable, Configuration}
-import org.apache.hadoop.io.Writable
-import org.apache.hadoop.mapreduce._
-import org.apache.hadoop.mapreduce.lib.input.{CombineFileSplit, FileSplit}
-import org.apache.hadoop.mapreduce.task.{JobContextImpl, TaskAttemptContextImpl}
-
-import org.apache.spark.{Partition => SparkPartition, _}
-import org.apache.spark.broadcast.Broadcast
-import org.apache.spark.deploy.SparkHadoopUtil
-import org.apache.spark.executor.DataReadMethod
-import org.apache.spark.internal.Logging
-import org.apache.spark.sql.SQLContext
-import org.apache.spark.sql.execution.datasources.parquet.VectorizedParquetRecordReader
-import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.storage.StorageLevel
-import org.apache.spark.util.{SerializableConfiguration, ShutdownHookManager}
-
-private[spark] class SqlNewHadoopPartition(
- rddId: Int,
- val index: Int,
- rawSplit: InputSplit with Writable)
- extends SparkPartition {
-
- val serializableHadoopSplit = new SerializableWritable(rawSplit)
-
- override def hashCode(): Int = 41 * (41 + rddId) + index
-}
-
-/**
- * An RDD that provides core functionality for reading data stored in Hadoop (e.g., files in HDFS,
- * sources in HBase, or S3), using the new MapReduce API (`org.apache.hadoop.mapreduce`).
- * It is based on [[org.apache.spark.rdd.NewHadoopRDD]]. It has three additions.
- * 1. A shared broadcast Hadoop Configuration.
- * 2. An optional closure `initDriverSideJobFuncOpt` that set configurations at the driver side
- * to the shared Hadoop Configuration.
- * 3. An optional closure `initLocalJobFuncOpt` that set configurations at both the driver side
- * and the executor side to the shared Hadoop Configuration.
- *
- * Note: This is RDD is basically a cloned version of [[org.apache.spark.rdd.NewHadoopRDD]] with
- * changes based on [[org.apache.spark.rdd.HadoopRDD]].
- */
-private[spark] class SqlNewHadoopRDD[V: ClassTag](
- sqlContext: SQLContext,
- broadcastedConf: Broadcast[SerializableConfiguration],
- @transient private val initDriverSideJobFuncOpt: Option[Job => Unit],
- initLocalJobFuncOpt: Option[Job => Unit],
- inputFormatClass: Class[_ <: InputFormat[Void, V]],
- valueClass: Class[V])
- extends RDD[V](sqlContext.sparkContext, Nil) with Logging {
-
- protected def getJob(): Job = {
- val conf = broadcastedConf.value.value
- // "new Job" will make a copy of the conf. Then, it is
- // safe to mutate conf properties with initLocalJobFuncOpt
- // and initDriverSideJobFuncOpt.
- val newJob = Job.getInstance(conf)
- initLocalJobFuncOpt.map(f => f(newJob))
- newJob
- }
-
- def getConf(isDriverSide: Boolean): Configuration = {
- val job = getJob()
- if (isDriverSide) {
- initDriverSideJobFuncOpt.map(f => f(job))
- }
- job.getConfiguration
- }
-
- private val jobTrackerId: String = {
- val formatter = new SimpleDateFormat("yyyyMMddHHmm")
- formatter.format(new Date())
- }
-
- @transient protected val jobId = new JobID(jobTrackerId, id)
-
- // If true, enable using the custom RecordReader for parquet. This only works for
- // a subset of the types (no complex types).
- protected val enableVectorizedParquetReader: Boolean =
- sqlContext.getConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key).toBoolean
- protected val enableWholestageCodegen: Boolean =
- sqlContext.getConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key).toBoolean
-
- override def getPartitions: Array[SparkPartition] = {
- val conf = getConf(isDriverSide = true)
- val inputFormat = inputFormatClass.newInstance
- inputFormat match {
- case configurable: Configurable =>
- configurable.setConf(conf)
- case _ =>
- }
- val jobContext = new JobContextImpl(conf, jobId)
- val rawSplits = inputFormat.getSplits(jobContext).toArray
- val result = new Array[SparkPartition](rawSplits.size)
- for (i <- 0 until rawSplits.size) {
- result(i) =
- new SqlNewHadoopPartition(id, i, rawSplits(i).asInstanceOf[InputSplit with Writable])
- }
- result
- }
-
- override def compute(
- theSplit: SparkPartition,
- context: TaskContext): Iterator[V] = {
- val iter = new Iterator[V] {
- val split = theSplit.asInstanceOf[SqlNewHadoopPartition]
- logInfo("Input split: " + split.serializableHadoopSplit)
- val conf = getConf(isDriverSide = false)
-
- val inputMetrics = context.taskMetrics().registerInputMetrics(DataReadMethod.Hadoop)
- val existingBytesRead = inputMetrics.bytesRead
-
- // Sets the thread local variable for the file's name
- split.serializableHadoopSplit.value match {
- case fs: FileSplit => SqlNewHadoopRDDState.setInputFileName(fs.getPath.toString)
- case _ => SqlNewHadoopRDDState.unsetInputFileName()
- }
-
- // Find a function that will return the FileSystem bytes read by this thread. Do this before
- // creating RecordReader, because RecordReader's constructor might read some bytes
- val getBytesReadCallback: Option[() => Long] = split.serializableHadoopSplit.value match {
- case _: FileSplit | _: CombineFileSplit =>
- SparkHadoopUtil.get.getFSBytesReadOnThreadCallback()
- case _ => None
- }
-
- // For Hadoop 2.5+, we get our input bytes from thread-local Hadoop FileSystem statistics.
- // If we do a coalesce, however, we are likely to compute multiple partitions in the same
- // task and in the same thread, in which case we need to avoid override values written by
- // previous partitions (SPARK-13071).
- def updateBytesRead(): Unit = {
- getBytesReadCallback.foreach { getBytesRead =>
- inputMetrics.setBytesRead(existingBytesRead + getBytesRead())
- }
- }
-
- val format = inputFormatClass.newInstance
- format match {
- case configurable: Configurable =>
- configurable.setConf(conf)
- case _ =>
- }
- val attemptId = new TaskAttemptID(jobTrackerId, id, TaskType.MAP, split.index, 0)
- val hadoopAttemptContext = new TaskAttemptContextImpl(conf, attemptId)
- private[this] var reader: RecordReader[Void, V] = null
-
- /**
- * If the format is ParquetInputFormat, try to create the optimized RecordReader. If this
- * fails (for example, unsupported schema), try with the normal reader.
- * TODO: plumb this through a different way?
- */
- if (enableVectorizedParquetReader &&
- format.getClass.getName == "org.apache.parquet.hadoop.ParquetInputFormat") {
- val parquetReader: VectorizedParquetRecordReader = new VectorizedParquetRecordReader()
- if (!parquetReader.tryInitialize(
- split.serializableHadoopSplit.value, hadoopAttemptContext)) {
- parquetReader.close()
- } else {
- reader = parquetReader.asInstanceOf[RecordReader[Void, V]]
- parquetReader.resultBatch()
- // Whole stage codegen (PhysicalRDD) is able to deal with batches directly
- if (enableWholestageCodegen) parquetReader.enableReturningBatches()
- }
- }
-
- if (reader == null) {
- reader = format.createRecordReader(
- split.serializableHadoopSplit.value, hadoopAttemptContext)
- reader.initialize(split.serializableHadoopSplit.value, hadoopAttemptContext)
- }
-
- // Register an on-task-completion callback to close the input stream.
- context.addTaskCompletionListener(context => close())
-
- private[this] var havePair = false
- private[this] var finished = false
-
- override def hasNext: Boolean = {
- if (context.isInterrupted()) {
- throw new TaskKilledException
- }
- if (!finished && !havePair) {
- finished = !reader.nextKeyValue
- if (finished) {
- // Close and release the reader here; close() will also be called when the task
- // completes, but for tasks that read from many files, it helps to release the
- // resources early.
- close()
- }
- havePair = !finished
- }
- !finished
- }
-
- override def next(): V = {
- if (!hasNext) {
- throw new java.util.NoSuchElementException("End of stream")
- }
- havePair = false
- if (!finished) {
- inputMetrics.incRecordsReadInternal(1)
- }
- if (inputMetrics.recordsRead % SparkHadoopUtil.UPDATE_INPUT_METRICS_INTERVAL_RECORDS == 0) {
- updateBytesRead()
- }
- reader.getCurrentValue
- }
-
- private def close() {
- if (reader != null) {
- SqlNewHadoopRDDState.unsetInputFileName()
- // Close the reader and release it. Note: it's very important that we don't close the
- // reader more than once, since that exposes us to MAPREDUCE-5918 when running against
- // Hadoop 1.x and older Hadoop 2.x releases. That bug can lead to non-deterministic
- // corruption issues when reading compressed input.
- try {
- reader.close()
- } catch {
- case e: Exception =>
- if (!ShutdownHookManager.inShutdown()) {
- logWarning("Exception in RecordReader.close()", e)
- }
- } finally {
- reader = null
- }
- if (getBytesReadCallback.isDefined) {
- updateBytesRead()
- } else if (split.serializableHadoopSplit.value.isInstanceOf[FileSplit] ||
- split.serializableHadoopSplit.value.isInstanceOf[CombineFileSplit]) {
- // If we can't get the bytes read from the FS stats, fall back to the split size,
- // which may be inaccurate.
- try {
- inputMetrics.incBytesReadInternal(split.serializableHadoopSplit.value.getLength)
- } catch {
- case e: java.io.IOException =>
- logWarning("Unable to get input size to set InputMetrics for task", e)
- }
- }
- }
- }
- }
- iter
- }
-
- override def getPreferredLocations(hsplit: SparkPartition): Seq[String] = {
- val split = hsplit.asInstanceOf[SqlNewHadoopPartition].serializableHadoopSplit.value
- val locs = HadoopRDD.SPLIT_INFO_REFLECTIONS match {
- case Some(c) =>
- try {
- val infos = c.newGetLocationInfo.invoke(split).asInstanceOf[Array[AnyRef]]
- Some(HadoopRDD.convertSplitLocationInfo(infos))
- } catch {
- case e : Exception =>
- logDebug("Failed to use InputSplit#getLocationInfo.", e)
- None
- }
- case None => None
- }
- locs.getOrElse(split.getLocations.filter(_ != "localhost"))
- }
-
- override def persist(storageLevel: StorageLevel): this.type = {
- if (storageLevel.deserialized) {
- logWarning("Caching NewHadoopRDDs as deserialized objects usually leads to undesired" +
- " behavior because Hadoop's RecordReader reuses the same Writable object for all records." +
- " Use a map transformation to make copies of the records.")
- }
- super.persist(storageLevel)
- }
-
- /**
- * Analogous to [[org.apache.spark.rdd.MapPartitionsRDD]], but passes in an InputSplit to
- * the given function rather than the index of the partition.
- */
- private[spark] class NewHadoopMapPartitionsWithSplitRDD[U: ClassTag, T: ClassTag](
- prev: RDD[T],
- f: (InputSplit, Iterator[T]) => Iterator[U],
- preservesPartitioning: Boolean = false)
- extends RDD[U](prev) {
-
- override val partitioner = if (preservesPartitioning) firstParent[T].partitioner else None
-
- override def getPartitions: Array[SparkPartition] = firstParent[T].partitions
-
- override def compute(split: SparkPartition, context: TaskContext): Iterator[U] = {
- val partition = split.asInstanceOf[SqlNewHadoopPartition]
- val inputSplit = partition.serializableHadoopSplit.value
- f(inputSplit, firstParent[T].iterator(split, context))
- }
- }
-}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala
index 233ac263aa..815d1d01ef 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala
@@ -33,9 +33,9 @@ import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.UnsafeKVExternalSorter
import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.sources.{HadoopFsRelation, OutputWriter, OutputWriterFactory}
+import org.apache.spark.sql.sources.{OutputWriter, OutputWriterFactory}
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
-import org.apache.spark.util.SerializableConfiguration
+import org.apache.spark.util.{SerializableConfiguration, Utils}
/** A container for all the details required when writing to a table. */
case class WriteRelation(
@@ -129,16 +129,17 @@ private[sql] abstract class BaseWriterContainer(
outputWriterFactory.newInstance(path, bucketId, dataSchema, taskAttemptContext)
} catch {
case e: org.apache.hadoop.fs.FileAlreadyExistsException =>
- if (outputCommitter.isInstanceOf[parquet.DirectParquetOutputCommitter]) {
- // Spark-11382: DirectParquetOutputCommitter is not idempotent, meaning on retry
+ if (outputCommitter.getClass.getName.contains("Direct")) {
+ // SPARK-11382: DirectParquetOutputCommitter is not idempotent, meaning on retry
// attempts, the task will fail because the output file is created from a prior attempt.
// This often means the most visible error to the user is misleading. Augment the error
// to tell the user to look for the actual error.
throw new SparkException("The output file already exists but this could be due to a " +
"failure from an earlier attempt. Look through the earlier logs or stage page for " +
- "the first error.\n File exists error: " + e)
+ "the first error.\n File exists error: " + e, e)
+ } else {
+ throw e
}
- throw e
}
}
@@ -156,15 +157,6 @@ private[sql] abstract class BaseWriterContainer(
s"Using default output committer ${defaultOutputCommitter.getClass.getCanonicalName} " +
"for appending.")
defaultOutputCommitter
- } else if (speculationEnabled) {
- // When speculation is enabled, it's not safe to use customized output committer classes,
- // especially direct output committers (e.g. `DirectParquetOutputCommitter`).
- //
- // See SPARK-9899 for more details.
- logInfo(
- s"Using default output committer ${defaultOutputCommitter.getClass.getCanonicalName} " +
- "because spark.speculation is configured to be true.")
- defaultOutputCommitter
} else {
val configuration = context.getConfiguration
val committerClass = configuration.getClass(
@@ -255,19 +247,16 @@ private[sql] class DefaultWriterContainer(
// If anything below fails, we should abort the task.
try {
- while (iterator.hasNext) {
- val internalRow = iterator.next()
- writer.writeInternal(internalRow)
- }
-
- commitTask()
+ Utils.tryWithSafeFinallyAndFailureCallbacks {
+ while (iterator.hasNext) {
+ val internalRow = iterator.next()
+ writer.writeInternal(internalRow)
+ }
+ commitTask()
+ }(catchBlock = abortTask())
} catch {
- case cause: Throwable =>
- logError("Aborting task.", cause)
- // call failure callbacks first, so we could have a chance to cleanup the writer.
- TaskContext.get().asInstanceOf[TaskContextImpl].markTaskFailed(cause)
- abortTask()
- throw new SparkException("Task failed while writing rows.", cause)
+ case t: Throwable =>
+ throw new SparkException("Task failed while writing rows", t)
}
def commitTask(): Unit = {
@@ -421,37 +410,37 @@ private[sql] class DynamicPartitionWriterContainer(
// If anything below fails, we should abort the task.
var currentWriter: OutputWriter = null
try {
- var currentKey: UnsafeRow = null
- while (sortedIterator.next()) {
- val nextKey = getBucketingKey(sortedIterator.getKey).asInstanceOf[UnsafeRow]
- if (currentKey != nextKey) {
- if (currentWriter != null) {
- currentWriter.close()
- currentWriter = null
+ Utils.tryWithSafeFinallyAndFailureCallbacks {
+ var currentKey: UnsafeRow = null
+ while (sortedIterator.next()) {
+ val nextKey = getBucketingKey(sortedIterator.getKey).asInstanceOf[UnsafeRow]
+ if (currentKey != nextKey) {
+ if (currentWriter != null) {
+ currentWriter.close()
+ currentWriter = null
+ }
+ currentKey = nextKey.copy()
+ logDebug(s"Writing partition: $currentKey")
+
+ currentWriter = newOutputWriter(currentKey, getPartitionString)
}
- currentKey = nextKey.copy()
- logDebug(s"Writing partition: $currentKey")
-
- currentWriter = newOutputWriter(currentKey, getPartitionString)
+ currentWriter.writeInternal(sortedIterator.getValue)
+ }
+ if (currentWriter != null) {
+ currentWriter.close()
+ currentWriter = null
}
- currentWriter.writeInternal(sortedIterator.getValue)
- }
- if (currentWriter != null) {
- currentWriter.close()
- currentWriter = null
- }
- commitTask()
- } catch {
- case cause: Throwable =>
- logError("Aborting task.", cause)
- // call failure callbacks first, so we could have a chance to cleanup the writer.
- TaskContext.get().asInstanceOf[TaskContextImpl].markTaskFailed(cause)
+ commitTask()
+ }(catchBlock = {
if (currentWriter != null) {
currentWriter.close()
}
abortTask()
- throw new SparkException("Task failed while writing rows.", cause)
+ })
+ } catch {
+ case t: Throwable =>
+ throw new SparkException("Task failed while writing rows", t)
}
}
}
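
The error-handling shape both containers now share, sketched in isolation (writeAllRows is a
made-up stand-in for the per-row loop above):

  Utils.tryWithSafeFinallyAndFailureCallbacks {
    writeAllRows()
    commitTask()
  }(catchBlock = abortTask())
  // On failure the task failure callbacks run before abortTask, so the open writer can still be
  // cleaned up, and the outer catch rethrows the cause wrapped in a SparkException.
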
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala
index 797f740dc5..ea843a1013 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala
@@ -33,11 +33,11 @@ import org.apache.spark.unsafe.types.UTF8String
private[csv] object CSVInferSchema {
/**
- * Similar to the JSON schema inference
- * 1. Infer type of each row
- * 2. Merge row types to find common type
- * 3. Replace any null types with string type
- */
+ * Similar to the JSON schema inference
+ * 1. Infer type of each row
+ * 2. Merge row types to find common type
+ * 3. Replace any null types with string type
+ */
def infer(
tokenRdd: RDD[Array[String]],
header: Array[String],
@@ -75,9 +75,9 @@ private[csv] object CSVInferSchema {
}
/**
- * Infer type of string field. Given known type Double, and a string "1", there is no
- * point checking if it is an Int, as the final type must be Double or higher.
- */
+ * Infer type of string field. Given known type Double, and a string "1", there is no
+ * point checking if it is an Int, as the final type must be Double or higher.
+ */
def inferField(typeSoFar: DataType, field: String, nullValue: String = ""): DataType = {
if (field == null || field.isEmpty || field == nullValue) {
typeSoFar
@@ -142,9 +142,9 @@ private[csv] object CSVInferSchema {
private val numericPrecedence: IndexedSeq[DataType] = HiveTypeCoercion.numericPrecedence
/**
- * Copied from internal Spark api
- * [[org.apache.spark.sql.catalyst.analysis.HiveTypeCoercion]]
- */
+ * Copied from internal Spark api
+ * [[org.apache.spark.sql.catalyst.analysis.HiveTypeCoercion]]
+ */
val findTightestCommonType: (DataType, DataType) => Option[DataType] = {
case (t1, t2) if t1 == t2 => Some(t1)
case (NullType, t1) => Some(t1)
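
A worked example of the three steps in the inference comment above (the sample values are
assumptions):

  import org.apache.spark.sql.types._
  val perRowTypes: Seq[DataType] = Seq(IntegerType, DoubleType, StringType) // from "1", "2.5", "abc"
  // Merging these row types has no tighter common type than StringType, so the column is read
  // as StringType; a column that only ever contained nulls is likewise replaced by StringType.
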
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala
index 95de02cf5c..7b9d3b605a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala
@@ -22,8 +22,7 @@ import java.nio.charset.StandardCharsets
import org.apache.spark.internal.Logging
import org.apache.spark.sql.execution.datasources.{CompressionCodecs, ParseModes}
-private[sql] class CSVOptions(
- @transient private val parameters: Map[String, String])
+private[sql] class CSVOptions(@transient private val parameters: Map[String, String])
extends Logging with Serializable {
private def getChar(paramName: String, default: Char): Char = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVParser.scala
index 7cf1b4c662..c3d863f547 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVParser.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVParser.scala
@@ -25,11 +25,11 @@ import com.univocity.parsers.csv.{CsvParser, CsvParserSettings, CsvWriter, CsvWr
import org.apache.spark.internal.Logging
/**
- * Read and parse CSV-like input
- *
- * @param params Parameters object
- * @param headers headers for the columns
- */
+ * Read and parse CSV-like input
+ *
+ * @param params Parameters object
+ * @param headers headers for the columns
+ */
private[sql] abstract class CsvReader(params: CSVOptions, headers: Seq[String]) {
protected lazy val parser: CsvParser = {
@@ -47,6 +47,7 @@ private[sql] abstract class CsvReader(params: CSVOptions, headers: Seq[String])
settings.setMaxColumns(params.maxColumns)
settings.setNullValue(params.nullValue)
settings.setMaxCharsPerColumn(params.maxCharsPerColumn)
+ settings.setParseUnescapedQuotesUntilDelimiter(true)
if (headers != null) settings.setHeaders(headers: _*)
new CsvParser(settings)
@@ -54,11 +55,11 @@ private[sql] abstract class CsvReader(params: CSVOptions, headers: Seq[String])
}
/**
- * Converts a sequence of string to CSV string
- *
- * @param params Parameters object for configuration
- * @param headers headers for columns
- */
+ * Converts a sequence of strings to a CSV string
+ *
+ * @param params Parameters object for configuration
+ * @param headers headers for columns
+ */
private[sql] class LineCsvWriter(params: CSVOptions, headers: Seq[String]) extends Logging {
private val writerSettings = new CsvWriterSettings
private val format = writerSettings.getFormat
@@ -90,18 +91,18 @@ private[sql] class LineCsvWriter(params: CSVOptions, headers: Seq[String]) exten
}
/**
- * Parser for parsing a line at a time. Not efficient for bulk data.
- *
- * @param params Parameters object
- */
+ * Parser for parsing a line at a time. Not efficient for bulk data.
+ *
+ * @param params Parameters object
+ */
private[sql] class LineCsvReader(params: CSVOptions)
extends CsvReader(params, null) {
/**
- * parse a line
- *
- * @param line a String with no newline at the end
- * @return array of strings where each string is a field in the CSV record
- */
+ * parse a line
+ *
+ * @param line a String with no newline at the end
+ * @return array of strings where each string is a field in the CSV record
+ */
def parseLine(line: String): Array[String] = {
parser.beginParsing(new StringReader(line))
val parsed = parser.parseNext()
@@ -111,12 +112,12 @@ private[sql] class LineCsvReader(params: CSVOptions)
}
/**
- * Parser for parsing lines in bulk. Use this when efficiency is desired.
- *
- * @param iter iterator over lines in the file
- * @param params Parameters object
- * @param headers headers for the columns
- */
+ * Parser for parsing lines in bulk. Use this when efficiency is desired.
+ *
+ * @param iter iterator over lines in the file
+ * @param params Parameters object
+ * @param headers headers for the columns
+ */
private[sql] class BulkCsvReader(
iter: Iterator[String],
params: CSVOptions,
@@ -128,9 +129,9 @@ private[sql] class BulkCsvReader(
private var nextRecord = parser.parseNext()
/**
- * get the next parsed line.
- * @return array of strings where each string is a field in the CSV record
- */
+ * get the next parsed line.
+ * @return array of strings where each string is a field in the CSV record
+ */
override def next(): Array[String] = {
val curRecord = nextRecord
if(curRecord != null) {
@@ -146,11 +147,11 @@ private[sql] class BulkCsvReader(
}
/**
- * A Reader that "reads" from a sequence of lines. Spark's textFile method removes newlines at
- * end of each line Univocity parser requires a Reader that provides access to the data to be
- * parsed and needs the newlines to be present
- * @param iter iterator over RDD[String]
- */
+ * A Reader that "reads" from a sequence of lines. Spark's textFile method removes newlines at
+ * end of each line Univocity parser requires a Reader that provides access to the data to be
+ * parsed and needs the newlines to be present
+ * @param iter iterator over RDD[String]
+ */
private class StringIteratorReader(val iter: Iterator[String]) extends java.io.Reader {
private var next: Long = 0
@@ -159,9 +160,9 @@ private class StringIteratorReader(val iter: Iterator[String]) extends java.io.R
private var str: String = null // current string from iter
/**
- * fetch next string from iter, if done with current one
- * pretend there is a new line at the end of every string we get from from iter
- */
+ * Fetch the next string from iter once we are done with the current one, and
+ * pretend there is a newline at the end of every string we get from iter.
+ */
private def refill(): Unit = {
if (length == next) {
if (iter.hasNext) {
@@ -175,8 +176,8 @@ private class StringIteratorReader(val iter: Iterator[String]) extends java.io.R
}
/**
- * read the next character, if at end of string pretend there is a new line
- */
+ * Read the next character; if at the end of the string, pretend there is a newline.
+ */
override def read(): Int = {
refill()
if (next >= length) {
@@ -189,8 +190,8 @@ private class StringIteratorReader(val iter: Iterator[String]) extends java.io.R
}
/**
- * read from str into cbuf
- */
+ * read from str into cbuf
+ */
override def read(cbuf: Array[Char], off: Int, len: Int): Int = {
refill()
var n = 0
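The re-indented scaladoc above describes StringIteratorReader: a java.io.Reader over an iterator of lines that pretends each line still ends with a newline, since the upstream source strips them. A rough standalone sketch of that idea, not the Spark class itself (the name LineIteratorReader and its fields are made up for illustration):

import java.io.Reader

class LineIteratorReader(iter: Iterator[String]) extends Reader {
  private var current: String = ""
  private var pos: Int = 0

  // Advance to the next line (with a trailing '\n' re-added) once the current one is exhausted.
  private def refill(): Unit = {
    if (pos >= current.length && iter.hasNext) {
      current = iter.next() + "\n"
      pos = 0
    }
  }

  override def read(): Int = {
    refill()
    if (pos >= current.length) -1
    else {
      val c = current.charAt(pos)
      pos += 1
      c.toInt
    }
  }

  override def read(cbuf: Array[Char], off: Int, len: Int): Int = {
    refill()
    if (pos >= current.length) {
      -1
    } else {
      val n = math.min(len, current.length - pos)
      current.getChars(pos, pos + n, cbuf, off)
      pos += n
      n
    }
  }

  override def close(): Unit = ()
}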
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala
index 5501015775..54fb03b6d3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.datasources.csv
import scala.util.control.NonFatal
-import org.apache.hadoop.fs.{FileStatus, Path}
+import org.apache.hadoop.fs.Path
import org.apache.hadoop.io.{NullWritable, Text}
import org.apache.hadoop.mapreduce.RecordWriter
import org.apache.hadoop.mapreduce.TaskAttemptContext
@@ -30,6 +30,7 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
+import org.apache.spark.sql.execution.datasources.PartitionedFile
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types._
@@ -41,22 +42,18 @@ object CSVRelation extends Logging {
firstLine: String,
params: CSVOptions): RDD[Array[String]] = {
// If header is set, make sure firstLine is materialized before sending to executors.
- file.mapPartitionsWithIndex({
- case (split, iter) => new BulkCsvReader(
+ file.mapPartitions { iter =>
+ new BulkCsvReader(
if (params.headerFlag) iter.filterNot(_ == firstLine) else iter,
params,
headers = header)
- }, true)
+ }
}
- def parseCsv(
- tokenizedRDD: RDD[Array[String]],
+ def csvParser(
schema: StructType,
requiredColumns: Array[String],
- inputs: Seq[FileStatus],
- sqlContext: SQLContext,
- params: CSVOptions): RDD[InternalRow] = {
-
+ params: CSVOptions): Array[String] => Option[InternalRow] = {
val schemaFields = schema.fields
val requiredFields = StructType(requiredColumns.map(schema(_))).fields
val safeRequiredFields = if (params.dropMalformed) {
@@ -74,7 +71,8 @@ object CSVRelation extends Logging {
}
val requiredSize = requiredFields.length
val row = new GenericMutableRow(requiredSize)
- tokenizedRDD.flatMap { tokens =>
+
+ (tokens: Array[String]) => {
if (params.dropMalformed && schemaFields.length != tokens.length) {
logWarning(s"Dropping malformed line: ${tokens.mkString(params.delimiter.toString)}")
None
@@ -118,6 +116,33 @@ object CSVRelation extends Logging {
}
}
}
+
+ def parseCsv(
+ tokenizedRDD: RDD[Array[String]],
+ schema: StructType,
+ requiredColumns: Array[String],
+ options: CSVOptions): RDD[InternalRow] = {
+ val parser = csvParser(schema, requiredColumns, options)
+ tokenizedRDD.flatMap(parser(_).toSeq)
+ }
+
+ // Skips the header line of each file if the `header` option is set to true.
+ def dropHeaderLine(
+ file: PartitionedFile, lines: Iterator[String], csvOptions: CSVOptions): Unit = {
+ // TODO What if the first partitioned file consists of only comments and empty lines?
+ if (csvOptions.headerFlag && file.start == 0) {
+ val nonEmptyLines = if (csvOptions.isCommentSet) {
+ val commentPrefix = csvOptions.comment.toString
+ lines.dropWhile { line =>
+ line.trim.isEmpty || line.trim.startsWith(commentPrefix)
+ }
+ } else {
+ lines.dropWhile(_.trim.isEmpty)
+ }
+
+ if (nonEmptyLines.hasNext) nonEmptyLines.drop(1)
+ }
+ }
}
private[sql] class CSVOutputWriterFactory(params: CSVOptions) extends OutputWriterFactory {
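The new dropHeaderLine helper above skips leading blank and comment lines and then drops one header line, but only for the partition that starts at file offset 0. A simplified, Spark-free sketch of the same skipping logic (the object and method names and the single-character comment marker are assumptions for illustration):

object CsvHeaderUtil {
  // Skip leading blank lines (and comment lines, if a comment marker is configured),
  // then drop a single header line.
  def dropHeader(lines: Iterator[String], hasHeader: Boolean, comment: Option[Char]): Iterator[String] = {
    if (!hasHeader) return lines
    val nonEmpty = comment match {
      case Some(c) => lines.dropWhile(l => l.trim.isEmpty || l.trim.startsWith(c.toString))
      case None    => lines.dropWhile(_.trim.isEmpty)
    }
    if (nonEmpty.hasNext) nonEmpty.drop(1) else nonEmpty
  }
}

// CsvHeaderUtil.dropHeader(Iterator("", "# note", "a,b", "1,2"), hasHeader = true, Some('#')).toList
// => List("1,2")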
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/DefaultSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/DefaultSource.scala
index 54e4c1a2c9..06a371b88b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/DefaultSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/DefaultSource.scala
@@ -19,25 +19,27 @@ package org.apache.spark.sql.execution.datasources.csv
import java.nio.charset.{Charset, StandardCharsets}
+import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.FileStatus
import org.apache.hadoop.io.{LongWritable, Text}
import org.apache.hadoop.mapred.TextInputFormat
-import org.apache.hadoop.mapreduce.Job
+import org.apache.hadoop.mapreduce._
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.UnsafeProjection
-import org.apache.spark.sql.execution.datasources.CompressionCodecs
+import org.apache.spark.sql.catalyst.expressions.{JoinedRow, UnsafeProjection}
+import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
+import org.apache.spark.sql.execution.datasources.{CompressionCodecs, HadoopFileLinesReader, PartitionedFile}
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types.{StringType, StructField, StructType}
import org.apache.spark.util.SerializableConfiguration
import org.apache.spark.util.collection.BitSet
/**
- * Provides access to CSV data from pure SQL statements.
- */
+ * Provides access to CSV data from pure SQL statements.
+ */
class DefaultSource extends FileFormat with DataSourceRegister {
override def shortName(): String = "csv"
@@ -91,39 +93,46 @@ class DefaultSource extends FileFormat with DataSourceRegister {
new CSVOutputWriterFactory(csvOptions)
}
- /**
- * This supports to eliminate unneeded columns before producing an RDD
- * containing all of its tuples as Row objects. This reads all the tokens of each line
- * and then drop unneeded tokens without casting and type-checking by mapping
- * both the indices produced by `requiredColumns` and the ones of tokens.
- */
- override def buildInternalScan(
+ override def buildReader(
sqlContext: SQLContext,
dataSchema: StructType,
- requiredColumns: Array[String],
- filters: Array[Filter],
- bucketSet: Option[BitSet],
- inputFiles: Seq[FileStatus],
- broadcastedConf: Broadcast[SerializableConfiguration],
- options: Map[String, String]): RDD[InternalRow] = {
- // TODO: Filter before calling buildInternalScan.
- val csvFiles = inputFiles.filterNot(_.getPath.getName startsWith "_")
-
+ partitionSchema: StructType,
+ requiredSchema: StructType,
+ filters: Seq[Filter],
+ options: Map[String, String]): (PartitionedFile) => Iterator[InternalRow] = {
val csvOptions = new CSVOptions(options)
- val pathsString = csvFiles.map(_.getPath.toUri.toString)
- val header = dataSchema.fields.map(_.name)
- val tokenizedRdd = tokenRdd(sqlContext, csvOptions, header, pathsString)
- val rows = CSVRelation.parseCsv(
- tokenizedRdd, dataSchema, requiredColumns, csvFiles, sqlContext, csvOptions)
-
- val requiredDataSchema = StructType(requiredColumns.map(c => dataSchema.find(_.name == c).get))
- rows.mapPartitions { iterator =>
- val unsafeProjection = UnsafeProjection.create(requiredDataSchema)
- iterator.map(unsafeProjection)
+ val headers = requiredSchema.fields.map(_.name)
+
+ val conf = new Configuration(sqlContext.sparkContext.hadoopConfiguration)
+ val broadcastedConf = sqlContext.sparkContext.broadcast(new SerializableConfiguration(conf))
+
+ (file: PartitionedFile) => {
+ val lineIterator = {
+ val conf = broadcastedConf.value.value
+ new HadoopFileLinesReader(file, conf).map { line =>
+ new String(line.getBytes, 0, line.getLength, csvOptions.charset)
+ }
+ }
+
+ CSVRelation.dropHeaderLine(file, lineIterator, csvOptions)
+
+ val unsafeRowIterator = {
+ val tokenizedIterator = new BulkCsvReader(lineIterator, csvOptions, headers)
+ val parser = CSVRelation.csvParser(dataSchema, requiredSchema.fieldNames, csvOptions)
+ tokenizedIterator.flatMap(parser(_).toSeq)
+ }
+
+ // Appends partition values
+ val fullOutput = requiredSchema.toAttributes ++ partitionSchema.toAttributes
+ val joinedRow = new JoinedRow()
+ val appendPartitionColumns = GenerateUnsafeProjection.generate(fullOutput, fullOutput)
+
+ unsafeRowIterator.map { dataRow =>
+ appendPartitionColumns(joinedRow(dataRow, file.partitionValues))
+ }
}
}
-
private def baseRdd(
sqlContext: SQLContext,
options: CSVOptions,
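buildReader above replaces the old whole-scan entry point with a closure that is built once on the driver and then invoked per PartitionedFile on executors. A schematic of that shape with plain stand-in types (FileSlice, openLines and parseLine are hypothetical placeholders for PartitionedFile, HadoopFileLinesReader and the csvParser function):

case class FileSlice(path: String, start: Long, length: Long)  // stand-in for PartitionedFile

// Driver side: anything computed here is captured by the closure and shipped to executors once.
def buildReader(parseLine: String => Option[Seq[Any]],
                openLines: FileSlice => Iterator[String]): FileSlice => Iterator[Seq[Any]] = {
  (file: FileSlice) => {
    val lines = openLines(file)                   // executor side: open just this file slice
    lines.flatMap(line => parseLine(line).toSeq)  // tokenize and convert; malformed rows become None
  }
}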
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala
index 24923bbb10..2e88d588be 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala
@@ -51,11 +51,11 @@ case class DescribeCommand(
}
/**
- * Used to represent the operation of create table using a data source.
+ * Used to represent the operation of create table using a data source.
*
- * @param allowExisting If it is true, we will do nothing when the table already exists.
- * If it is false, an exception will be thrown
- */
+ * @param allowExisting If it is true, we will do nothing when the table already exists.
+ * If it is false, an exception will be thrown
+ */
case class CreateTableUsing(
tableIdent: TableIdentifier,
userSpecifiedSchema: Option[StructType],
@@ -107,7 +107,7 @@ case class CreateTempTableUsing(
sqlContext.sessionState.catalog.createTempTable(
tableIdent.table,
Dataset.ofRows(sqlContext, LogicalRelation(dataSource.resolveRelation())).logicalPlan,
- ignoreIfExists = true)
+ overrideIfExists = true)
Seq.empty[Row]
}
@@ -138,7 +138,7 @@ case class CreateTempTableUsingAsSelect(
sqlContext.sessionState.catalog.createTempTable(
tableIdent.table,
Dataset.ofRows(sqlContext, LogicalRelation(result)).logicalPlan,
- ignoreIfExists = true)
+ overrideIfExists = true)
Seq.empty[Row]
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
index b7ff5f7242..065c8572b0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
@@ -251,12 +251,12 @@ object JdbcUtils extends Logging {
def schemaString(df: DataFrame, url: String): String = {
val sb = new StringBuilder()
val dialect = JdbcDialects.get(url)
- df.schema.fields foreach { field => {
+ df.schema.fields foreach { field =>
val name = field.name
val typ: String = getJdbcType(field.dataType, dialect).databaseTypeDefinition
val nullable = if (field.nullable) "" else "NOT NULL"
sb.append(s", $name $typ $nullable")
- }}
+ }
if (sb.length < 2) "" else sb.substring(2)
}
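For reference, the tidied loop above just concatenates per-column DDL fragments and strips the leading ", ". A plain-Scala rendering of the same construction with made-up field triples in place of a DataFrame schema:

val fields = Seq(("id", "BIGINT", false), ("name", "TEXT", true))
val sb = new StringBuilder()
fields.foreach { case (name, typ, nullable) =>
  val nullability = if (nullable) "" else "NOT NULL"
  sb.append(s", $name $typ $nullability")
}
val ddlColumns = if (sb.length < 2) "" else sb.substring(2)
// ddlColumns == "id BIGINT NOT NULL, name TEXT " (nullable columns get an empty constraint)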
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/InferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/InferSchema.scala
index 945ed2c211..8e8238a594 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/InferSchema.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/InferSchema.scala
@@ -25,7 +25,6 @@ import org.apache.spark.sql.execution.datasources.json.JacksonUtils.nextUntil
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils
-
private[sql] object InferSchema {
/**
@@ -135,14 +134,20 @@ private[sql] object InferSchema {
// when we see a Java BigInteger, we use DecimalType.
case BIG_INTEGER | BIG_DECIMAL =>
val v = parser.getDecimalValue
- DecimalType(v.precision(), v.scale())
- case FLOAT | DOUBLE =>
- if (configOptions.floatAsBigDecimal) {
- val v = parser.getDecimalValue
- DecimalType(v.precision(), v.scale())
+ if (Math.max(v.precision(), v.scale()) <= DecimalType.MAX_PRECISION) {
+ DecimalType(Math.max(v.precision(), v.scale()), v.scale())
} else {
DoubleType
}
+ case FLOAT | DOUBLE if configOptions.prefersDecimal =>
+ val v = parser.getDecimalValue
+ if (Math.max(v.precision(), v.scale()) <= DecimalType.MAX_PRECISION) {
+ DecimalType(Math.max(v.precision(), v.scale()), v.scale())
+ } else {
+ DoubleType
+ }
+ case FLOAT | DOUBLE =>
+ DoubleType
}
case VALUE_TRUE | VALUE_FALSE => BooleanType
@@ -251,6 +256,14 @@ private[sql] object InferSchema {
case (ArrayType(elementType1, containsNull1), ArrayType(elementType2, containsNull2)) =>
ArrayType(compatibleType(elementType1, elementType2), containsNull1 || containsNull2)
+ // The case that given `DecimalType` is capable of given `IntegralType` is handled in
+ // `findTightestCommonTypeOfTwo`. Both cases below will be executed only when
+ // the given `DecimalType` is not capable of the given `IntegralType`.
+ case (t1: IntegralType, t2: DecimalType) =>
+ compatibleType(DecimalType.forType(t1), t2)
+ case (t1: DecimalType, t2: IntegralType) =>
+ compatibleType(t1, DecimalType.forType(t2))
+
// strings and every string is a Json object.
case (_, _) => StringType
}
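The inference change above only keeps DecimalType while the value still fits Spark's maximum decimal precision, falling back to DoubleType otherwise. A standalone sketch of that decision for floating-point tokens (MaxPrecision mirrors DecimalType.MAX_PRECISION, which is 38; the string return values are labels, not Spark types):

import java.math.BigDecimal

val MaxPrecision = 38

def inferFloatingType(v: BigDecimal, prefersDecimal: Boolean): String = {
  val precision = math.max(v.precision, v.scale)
  if (prefersDecimal && precision <= MaxPrecision) s"DecimalType($precision, ${v.scale})"
  else "DoubleType"
}

// inferFloatingType(new BigDecimal("3.14"), prefersDecimal = true)   // "DecimalType(3, 2)"
// inferFloatingType(new BigDecimal("3.14"), prefersDecimal = false)  // "DoubleType"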
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONOptions.scala
index c0ad9efcb7..66f1126fb9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONOptions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONOptions.scala
@@ -35,8 +35,8 @@ private[sql] class JSONOptions(
parameters.get("samplingRatio").map(_.toDouble).getOrElse(1.0)
val primitivesAsString =
parameters.get("primitivesAsString").map(_.toBoolean).getOrElse(false)
- val floatAsBigDecimal =
- parameters.get("floatAsBigDecimal").map(_.toBoolean).getOrElse(false)
+ val prefersDecimal =
+ parameters.get("prefersDecimal").map(_.toBoolean).getOrElse(false)
val allowComments =
parameters.get("allowComments").map(_.toBoolean).getOrElse(false)
val allowUnquotedFieldNames =
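Read-side usage of the renamed option, assuming an existing sqlContext and a hypothetical input path:

val df = sqlContext.read
  .option("prefersDecimal", "true")   // infer floating-point values as DecimalType when they fit
  .json("/path/to/data.json")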
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala
index 3bf0af0efa..7364a1dc06 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.datasources.json
import java.io.CharArrayWriter
import com.fasterxml.jackson.core.JsonFactory
+import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileStatus, Path}
import org.apache.hadoop.io.{LongWritable, NullWritable, Text}
import org.apache.hadoop.mapred.{JobConf, TextInputFormat}
@@ -27,17 +28,16 @@ import org.apache.hadoop.mapreduce.{Job, RecordWriter, TaskAttemptContext}
import org.apache.hadoop.mapreduce.lib.input.FileInputFormat
import org.apache.hadoop.mapreduce.lib.output.TextOutputFormat
-import org.apache.spark.broadcast.Broadcast
import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{AnalysisException, Row, SQLContext}
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.UnsafeProjection
+import org.apache.spark.sql.catalyst.expressions.JoinedRow
+import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types.StructType
import org.apache.spark.util.SerializableConfiguration
-import org.apache.spark.util.collection.BitSet
class DefaultSource extends FileFormat with DataSourceRegister {
@@ -91,32 +91,37 @@ class DefaultSource extends FileFormat with DataSourceRegister {
}
}
- override def buildInternalScan(
+ override def buildReader(
sqlContext: SQLContext,
dataSchema: StructType,
- requiredColumns: Array[String],
- filters: Array[Filter],
- bucketSet: Option[BitSet],
- inputFiles: Seq[FileStatus],
- broadcastedConf: Broadcast[SerializableConfiguration],
- options: Map[String, String]): RDD[InternalRow] = {
- // TODO: Filter files for all formats before calling buildInternalScan.
- val jsonFiles = inputFiles.filterNot(_.getPath.getName startsWith "_")
+ partitionSchema: StructType,
+ requiredSchema: StructType,
+ filters: Seq[Filter],
+ options: Map[String, String]): PartitionedFile => Iterator[InternalRow] = {
+ val conf = new Configuration(sqlContext.sparkContext.hadoopConfiguration)
+ val broadcastedConf =
+ sqlContext.sparkContext.broadcast(new SerializableConfiguration(conf))
val parsedOptions: JSONOptions = new JSONOptions(options)
- val requiredDataSchema = StructType(requiredColumns.map(dataSchema(_)))
- val columnNameOfCorruptRecord =
- parsedOptions.columnNameOfCorruptRecord
- .getOrElse(sqlContext.conf.columnNameOfCorruptRecord)
- val rows = JacksonParser.parse(
- createBaseRdd(sqlContext, jsonFiles),
- requiredDataSchema,
- columnNameOfCorruptRecord,
- parsedOptions)
-
- rows.mapPartitions { iterator =>
- val unsafeProjection = UnsafeProjection.create(requiredDataSchema)
- iterator.map(unsafeProjection)
+ val columnNameOfCorruptRecord = parsedOptions.columnNameOfCorruptRecord
+ .getOrElse(sqlContext.conf.columnNameOfCorruptRecord)
+
+ val fullSchema = requiredSchema.toAttributes ++ partitionSchema.toAttributes
+ val joinedRow = new JoinedRow()
+
+ file => {
+ val lines = new HadoopFileLinesReader(file, broadcastedConf.value.value).map(_.toString)
+
+ val rows = JacksonParser.parseJson(
+ lines,
+ requiredSchema,
+ columnNameOfCorruptRecord,
+ parsedOptions)
+
+ val appendPartitionColumns = GenerateUnsafeProjection.generate(fullSchema, fullSchema)
+ rows.map { row =>
+ appendPartitionColumns(joinedRow(row, file.partitionValues))
+ }
}
}
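The JSON reader above ends the same way as the CSV one: each data row is joined with the constant partition values of the file being read. Sketched here with plain collections instead of InternalRow/JoinedRow (the helper name is made up):

def appendPartitionValues(rows: Iterator[Seq[Any]], partitionValues: Seq[Any]): Iterator[Seq[Any]] =
  rows.map(_ ++ partitionValues)

// appendPartitionValues(Iterator(Seq(1, "a"), Seq(2, "b")), Seq("2016-04-14")).toList
// yields two rows, each ending with the partition value "2016-04-14"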
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonParser.scala
index 00c14adf07..aeee2600a1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonParser.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonParser.scala
@@ -54,9 +54,9 @@ object JacksonParser extends Logging {
* with an array.
*/
def convertRootField(
- factory: JsonFactory,
- parser: JsonParser,
- schema: DataType): Any = {
+ factory: JsonFactory,
+ parser: JsonParser,
+ schema: DataType): Any = {
import com.fasterxml.jackson.core.JsonToken._
(parser.getCurrentToken, schema) match {
case (START_ARRAY, st: StructType) =>
@@ -250,7 +250,7 @@ object JacksonParser extends Logging {
new GenericArrayData(values.toArray)
}
- private def parseJson(
+ def parseJson(
input: Iterator[String],
schema: StructType,
columnNameOfCorruptRecords: String,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/DirectParquetOutputCommitter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/DirectParquetOutputCommitter.scala
deleted file mode 100644
index ecadb9e7c6..0000000000
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/DirectParquetOutputCommitter.scala
+++ /dev/null
@@ -1,88 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.execution.datasources.parquet
-
-import org.apache.hadoop.conf.Configuration
-import org.apache.hadoop.fs.Path
-import org.apache.hadoop.mapreduce.{JobContext, TaskAttemptContext}
-import org.apache.hadoop.mapreduce.lib.output.FileOutputCommitter
-import org.apache.parquet.Log
-import org.apache.parquet.hadoop.{ParquetFileReader, ParquetFileWriter, ParquetOutputCommitter, ParquetOutputFormat}
-import org.apache.parquet.hadoop.util.ContextUtil
-
-/**
- * An output committer for writing Parquet files. In stead of writing to the `_temporary` folder
- * like what [[ParquetOutputCommitter]] does, this output committer writes data directly to the
- * destination folder. This can be useful for data stored in S3, where directory operations are
- * relatively expensive.
- *
- * To enable this output committer, users may set the "spark.sql.parquet.output.committer.class"
- * property via Hadoop [[Configuration]]. Not that this property overrides
- * "spark.sql.sources.outputCommitterClass".
- *
- * *NOTE*
- *
- * NEVER use [[DirectParquetOutputCommitter]] when appending data, because currently there's
- * no safe way undo a failed appending job (that's why both `abortTask()` and `abortJob()` are
- * left empty).
- */
-private[datasources] class DirectParquetOutputCommitter(
- outputPath: Path, context: TaskAttemptContext)
- extends ParquetOutputCommitter(outputPath, context) {
- val LOG = Log.getLog(classOf[ParquetOutputCommitter])
-
- override def getWorkPath: Path = outputPath
- override def abortTask(taskContext: TaskAttemptContext): Unit = {}
- override def commitTask(taskContext: TaskAttemptContext): Unit = {}
- override def needsTaskCommit(taskContext: TaskAttemptContext): Boolean = true
- override def setupJob(jobContext: JobContext): Unit = {}
- override def setupTask(taskContext: TaskAttemptContext): Unit = {}
-
- override def commitJob(jobContext: JobContext) {
- val configuration = ContextUtil.getConfiguration(jobContext)
- val fileSystem = outputPath.getFileSystem(configuration)
-
- if (configuration.getBoolean(ParquetOutputFormat.ENABLE_JOB_SUMMARY, true)) {
- try {
- val outputStatus = fileSystem.getFileStatus(outputPath)
- val footers = ParquetFileReader.readAllFootersInParallel(configuration, outputStatus)
- try {
- ParquetFileWriter.writeMetadataFile(configuration, outputPath, footers)
- } catch { case e: Exception =>
- LOG.warn("could not write summary file for " + outputPath, e)
- val metadataPath = new Path(outputPath, ParquetFileWriter.PARQUET_METADATA_FILE)
- if (fileSystem.exists(metadataPath)) {
- fileSystem.delete(metadataPath, true)
- }
- }
- } catch {
- case e: Exception => LOG.warn("could not write summary file for " + outputPath, e)
- }
- }
-
- if (configuration.getBoolean("mapreduce.fileoutputcommitter.marksuccessfuljobs", true)) {
- try {
- val successPath = new Path(outputPath, FileOutputCommitter.SUCCEEDED_FILE_NAME)
- fileSystem.create(successPath).close()
- } catch {
- case e: Exception => LOG.warn("could not write success file for " + outputPath, e)
- }
- }
- }
-}
-
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala
new file mode 100644
index 0000000000..00352f23ae
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala
@@ -0,0 +1,59 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources.parquet
+
+import org.apache.parquet.hadoop.metadata.CompressionCodecName
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.internal.SQLConf
+
+/**
+ * Options for the Parquet data source.
+ */
+class ParquetOptions(
+ @transient private val parameters: Map[String, String],
+ @transient private val sqlConf: SQLConf)
+ extends Logging with Serializable {
+
+ import ParquetOptions._
+
+ /**
+ * Compression codec to use. By default use the value specified in SQLConf.
+ * Acceptable values are defined in [[shortParquetCompressionCodecNames]].
+ */
+ val compressionCodec: String = {
+ val codecName = parameters.getOrElse("compression", sqlConf.parquetCompressionCodec).toLowerCase
+ if (!shortParquetCompressionCodecNames.contains(codecName)) {
+ val availableCodecs = shortParquetCompressionCodecNames.keys.map(_.toLowerCase)
+ throw new IllegalArgumentException(s"Codec [$codecName] " +
+ s"is not available. Available codecs are ${availableCodecs.mkString(", ")}.")
+ }
+ shortParquetCompressionCodecNames(codecName).name()
+ }
+}
+
+
+object ParquetOptions {
+ // The parquet compression short names
+ private val shortParquetCompressionCodecNames = Map(
+ "none" -> CompressionCodecName.UNCOMPRESSED,
+ "uncompressed" -> CompressionCodecName.UNCOMPRESSED,
+ "snappy" -> CompressionCodecName.SNAPPY,
+ "gzip" -> CompressionCodecName.GZIP,
+ "lzo" -> CompressionCodecName.LZO)
+}
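The codec lookup in the new ParquetOptions class can be exercised in isolation roughly like this (plain strings stand in for the CompressionCodecName constants; resolveCodec is a made-up name):

val shortNames = Map(
  "none" -> "UNCOMPRESSED",
  "uncompressed" -> "UNCOMPRESSED",
  "snappy" -> "SNAPPY",
  "gzip" -> "GZIP",
  "lzo" -> "LZO")

def resolveCodec(requested: String): String = {
  val key = requested.toLowerCase
  shortNames.getOrElse(key, throw new IllegalArgumentException(
    s"Codec [$key] is not available. Available codecs are ${shortNames.keys.mkString(", ")}."))
}

// resolveCodec("Snappy")  // "SNAPPY"
// resolveCodec("zstd")    // throws IllegalArgumentException listing the supported names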
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala
index d6b84be267..b91e892f8f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala
@@ -18,33 +18,27 @@
package org.apache.spark.sql.execution.datasources.parquet
import java.net.URI
-import java.util.{List => JList}
import java.util.logging.{Logger => JLogger}
import scala.collection.JavaConverters._
import scala.collection.mutable
import scala.util.{Failure, Try}
-import scala.util.control.NonFatal
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileStatus, Path}
-import org.apache.hadoop.io.Writable
import org.apache.hadoop.mapreduce._
import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat, FileSplit}
-import org.apache.hadoop.mapreduce.task.{JobContextImpl, TaskAttemptContextImpl}
+import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl
import org.apache.parquet.{Log => ApacheParquetLog}
import org.apache.parquet.filter2.compat.FilterCompat
import org.apache.parquet.filter2.predicate.FilterApi
import org.apache.parquet.hadoop._
-import org.apache.parquet.hadoop.metadata.CompressionCodecName
import org.apache.parquet.hadoop.util.ContextUtil
import org.apache.parquet.schema.MessageType
import org.slf4j.bridge.SLF4JBridgeHandler
-import org.apache.spark.{Partition => SparkPartition, SparkException}
-import org.apache.spark.broadcast.Broadcast
+import org.apache.spark.SparkException
import org.apache.spark.internal.Logging
-import org.apache.spark.rdd.{RDD, SqlNewHadoopPartition, SqlNewHadoopRDD}
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.JoinedRow
@@ -53,9 +47,8 @@ import org.apache.spark.sql.catalyst.parser.LegacyTypeStringParser
import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.sources._
-import org.apache.spark.sql.types.{DataType, StructType}
-import org.apache.spark.util.{SerializableConfiguration, Utils}
-import org.apache.spark.util.collection.BitSet
+import org.apache.spark.sql.types.{AtomicType, DataType, StructType}
+import org.apache.spark.util.SerializableConfiguration
private[sql] class DefaultSource
extends FileFormat
@@ -75,14 +68,9 @@ private[sql] class DefaultSource
options: Map[String, String],
dataSchema: StructType): OutputWriterFactory = {
- val conf = ContextUtil.getConfiguration(job)
+ val parquetOptions = new ParquetOptions(options, sqlContext.sessionState.conf)
- // SPARK-9849 DirectParquetOutputCommitter qualified name should be backward compatible
- val committerClassName = conf.get(SQLConf.PARQUET_OUTPUT_COMMITTER_CLASS.key)
- if (committerClassName == "org.apache.spark.sql.parquet.DirectParquetOutputCommitter") {
- conf.set(SQLConf.PARQUET_OUTPUT_COMMITTER_CLASS.key,
- classOf[DirectParquetOutputCommitter].getCanonicalName)
- }
+ val conf = ContextUtil.getConfiguration(job)
val committerClass =
conf.getClass(
@@ -92,24 +80,11 @@ private[sql] class DefaultSource
if (conf.get(SQLConf.PARQUET_OUTPUT_COMMITTER_CLASS.key) == null) {
logInfo("Using default output committer for Parquet: " +
- classOf[ParquetOutputCommitter].getCanonicalName)
+ classOf[ParquetOutputCommitter].getCanonicalName)
} else {
logInfo("Using user defined output committer for Parquet: " + committerClass.getCanonicalName)
}
- val compressionCodec: Option[String] = options
- .get("compression")
- .map { codecName =>
- // Validate if given compression codec is supported or not.
- val shortParquetCompressionCodecNames = ParquetRelation.shortParquetCompressionCodecNames
- if (!shortParquetCompressionCodecNames.contains(codecName.toLowerCase)) {
- val availableCodecs = shortParquetCompressionCodecNames.keys.map(_.toLowerCase)
- throw new IllegalArgumentException(s"Codec [$codecName] " +
- s"is not available. Available codecs are ${availableCodecs.mkString(", ")}.")
- }
- codecName.toLowerCase
- }
-
conf.setClass(
SQLConf.OUTPUT_COMMITTER_CLASS.key,
committerClass,
@@ -144,14 +119,7 @@ private[sql] class DefaultSource
sqlContext.conf.writeLegacyParquetFormat.toString)
// Sets compression scheme
- conf.set(
- ParquetOutputFormat.COMPRESSION,
- ParquetRelation
- .shortParquetCompressionCodecNames
- .getOrElse(
- compressionCodec
- .getOrElse(sqlContext.conf.parquetCompressionCodec.toLowerCase),
- CompressionCodecName.UNCOMPRESSED).name())
+ conf.set(ParquetOutputFormat.COMPRESSION, parquetOptions.compressionCodec)
new OutputWriterFactory {
override def newInstance(
@@ -277,37 +245,35 @@ private[sql] class DefaultSource
}
/**
- * Returns a function that can be used to read a single file in as an Iterator of InternalRow.
- *
- * @param partitionSchema The schema of the partition column row that will be present in each
- * PartitionedFile. These columns should be prepended to the rows that
- * are produced by the iterator.
- * @param dataSchema The schema of the data that should be output for each row. This may be a
- * subset of the columns that are present in the file if column pruning has
- * occurred.
- * @param filters A set of filters than can optionally be used to reduce the number of rows output
- * @param options A set of string -> string configuration options.
- * @return
+ * Returns whether the reader will return the rows as a batch or not.

*/
+ override def supportBatch(sqlContext: SQLContext, schema: StructType): Boolean = {
+ val conf = SQLContext.getActive().get.conf
+ conf.parquetVectorizedReaderEnabled && conf.wholeStageEnabled &&
+ schema.length <= conf.wholeStageMaxNumFields &&
+ schema.forall(_.dataType.isInstanceOf[AtomicType])
+ }
+
override def buildReader(
sqlContext: SQLContext,
- partitionSchema: StructType,
dataSchema: StructType,
+ partitionSchema: StructType,
+ requiredSchema: StructType,
filters: Seq[Filter],
options: Map[String, String]): PartitionedFile => Iterator[InternalRow] = {
val parquetConf = new Configuration(sqlContext.sparkContext.hadoopConfiguration)
parquetConf.set(ParquetInputFormat.READ_SUPPORT_CLASS, classOf[CatalystReadSupport].getName)
parquetConf.set(
CatalystReadSupport.SPARK_ROW_REQUESTED_SCHEMA,
- CatalystSchemaConverter.checkFieldNames(dataSchema).json)
+ CatalystSchemaConverter.checkFieldNames(requiredSchema).json)
parquetConf.set(
CatalystWriteSupport.SPARK_ROW_SCHEMA,
- CatalystSchemaConverter.checkFieldNames(dataSchema).json)
+ CatalystSchemaConverter.checkFieldNames(requiredSchema).json)
// We want to clear this temporary metadata from saving into Parquet file.
// This metadata is only useful for detecting optional columns when pushdowning filters.
val dataSchemaToWrite = StructType.removeMetadata(StructType.metadataKeyForOptionalField,
- dataSchema).asInstanceOf[StructType]
+ requiredSchema).asInstanceOf[StructType]
CatalystWriteSupport.setSchema(dataSchemaToWrite, parquetConf)
// Sets flags for `CatalystSchemaConverter`
@@ -318,13 +284,17 @@ private[sql] class DefaultSource
SQLConf.PARQUET_INT96_AS_TIMESTAMP.key,
sqlContext.conf.getConf(SQLConf.PARQUET_INT96_AS_TIMESTAMP))
+ // Whole stage codegen (PhysicalRDD) is able to deal with batches directly
+ val returningBatch =
+ supportBatch(sqlContext, StructType(partitionSchema.fields ++ dataSchema.fields))
+
// Try to push down filters when filter push-down is enabled.
val pushed = if (sqlContext.getConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key).toBoolean) {
filters
// Collects all converted Parquet filter predicates. Notice that not all predicates can be
// converted (`ParquetFilters.createFilter` returns an `Option`). That's why a `flatMap`
// is used here.
- .flatMap(ParquetFilters.createFilter(dataSchema, _))
+ .flatMap(ParquetFilters.createFilter(requiredSchema, _))
.reduceOption(FilterApi.and)
} else {
None
@@ -336,10 +306,8 @@ private[sql] class DefaultSource
// TODO: if you move this into the closure it reverts to the default values.
// If true, enable using the custom RecordReader for parquet. This only works for
// a subset of the types (no complex types).
- val enableVectorizedParquetReader: Boolean =
- sqlContext.getConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key).toBoolean
- val enableWholestageCodegen: Boolean =
- sqlContext.getConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key).toBoolean
+ val enableVectorizedParquetReader: Boolean = sqlContext.conf.parquetVectorizedReaderEnabled &&
+ dataSchema.forall(_.dataType.isInstanceOf[AtomicType])
(file: PartitionedFile) => {
assert(file.partitionValues.numFields == partitionSchema.size)
@@ -359,32 +327,27 @@ private[sql] class DefaultSource
val attemptId = new TaskAttemptID(new TaskID(new JobID(), TaskType.MAP, 0), 0)
val hadoopAttemptContext = new TaskAttemptContextImpl(broadcastedConf.value.value, attemptId)
- val parquetReader = try {
- if (!enableVectorizedParquetReader) sys.error("Vectorized reader turned off.")
+ val parquetReader = if (enableVectorizedParquetReader) {
val vectorizedReader = new VectorizedParquetRecordReader()
vectorizedReader.initialize(split, hadoopAttemptContext)
logDebug(s"Appending $partitionSchema ${file.partitionValues}")
vectorizedReader.initBatch(partitionSchema, file.partitionValues)
- // Whole stage codegen (PhysicalRDD) is able to deal with batches directly
- // TODO: fix column appending
- if (enableWholestageCodegen) {
- logDebug(s"Enabling batch returning")
+ if (returningBatch) {
vectorizedReader.enableReturningBatches()
}
vectorizedReader
- } catch {
- case NonFatal(e) =>
- logDebug(s"Falling back to parquet-mr: $e", e)
- val reader = pushed match {
- case Some(filter) =>
- new ParquetRecordReader[InternalRow](
- new CatalystReadSupport,
- FilterCompat.get(filter, null))
- case _ =>
- new ParquetRecordReader[InternalRow](new CatalystReadSupport)
- }
- reader.initialize(split, hadoopAttemptContext)
- reader
+ } else {
+ logDebug(s"Falling back to parquet-mr")
+ val reader = pushed match {
+ case Some(filter) =>
+ new ParquetRecordReader[InternalRow](
+ new CatalystReadSupport,
+ FilterCompat.get(filter, null))
+ case _ =>
+ new ParquetRecordReader[InternalRow](new CatalystReadSupport)
+ }
+ reader.initialize(split, hadoopAttemptContext)
+ reader
}
val iter = new RecordReaderIterator(parquetReader)
@@ -394,7 +357,7 @@ private[sql] class DefaultSource
enableVectorizedParquetReader) {
iter.asInstanceOf[Iterator[InternalRow]]
} else {
- val fullSchema = dataSchema.toAttributes ++ partitionSchema.toAttributes
+ val fullSchema = requiredSchema.toAttributes ++ partitionSchema.toAttributes
val joinedRow = new JoinedRow()
val appendPartitionColumns = GenerateUnsafeProjection.generate(fullSchema, fullSchema)
@@ -406,91 +369,6 @@ private[sql] class DefaultSource
}
}
}
-
- override def buildInternalScan(
- sqlContext: SQLContext,
- dataSchema: StructType,
- requiredColumns: Array[String],
- filters: Array[Filter],
- bucketSet: Option[BitSet],
- allFiles: Seq[FileStatus],
- broadcastedConf: Broadcast[SerializableConfiguration],
- options: Map[String, String]): RDD[InternalRow] = {
- val useMetadataCache = sqlContext.getConf(SQLConf.PARQUET_CACHE_METADATA)
- val parquetFilterPushDown = sqlContext.conf.parquetFilterPushDown
- val assumeBinaryIsString = sqlContext.conf.isParquetBinaryAsString
- val assumeInt96IsTimestamp = sqlContext.conf.isParquetINT96AsTimestamp
-
- // Parquet row group size. We will use this value as the value for
- // mapreduce.input.fileinputformat.split.minsize and mapred.min.split.size if the value
- // of these flags are smaller than the parquet row group size.
- val parquetBlockSize = ParquetOutputFormat.getLongBlockSize(broadcastedConf.value.value)
-
- // Create the function to set variable Parquet confs at both driver and executor side.
- val initLocalJobFuncOpt =
- ParquetRelation.initializeLocalJobFunc(
- requiredColumns,
- filters,
- dataSchema,
- parquetBlockSize,
- useMetadataCache,
- parquetFilterPushDown,
- assumeBinaryIsString,
- assumeInt96IsTimestamp) _
-
- val inputFiles = splitFiles(allFiles).data.toArray
-
- // Create the function to set input paths at the driver side.
- val setInputPaths =
- ParquetRelation.initializeDriverSideJobFunc(inputFiles, parquetBlockSize) _
-
- Utils.withDummyCallSite(sqlContext.sparkContext) {
- new SqlNewHadoopRDD(
- sqlContext = sqlContext,
- broadcastedConf = broadcastedConf,
- initDriverSideJobFuncOpt = Some(setInputPaths),
- initLocalJobFuncOpt = Some(initLocalJobFuncOpt),
- inputFormatClass = classOf[ParquetInputFormat[InternalRow]],
- valueClass = classOf[InternalRow]) {
-
- val cacheMetadata = useMetadataCache
-
- @transient val cachedStatuses = inputFiles.map { f =>
- // In order to encode the authority of a Path containing special characters such as '/'
- // (which does happen in some S3N credentials), we need to use the string returned by the
- // URI of the path to create a new Path.
- val pathWithEscapedAuthority = escapePathUserInfo(f.getPath)
- new FileStatus(
- f.getLen, f.isDirectory, f.getReplication, f.getBlockSize, f.getModificationTime,
- f.getAccessTime, f.getPermission, f.getOwner, f.getGroup, pathWithEscapedAuthority)
- }.toSeq
-
- private def escapePathUserInfo(path: Path): Path = {
- val uri = path.toUri
- new Path(new URI(
- uri.getScheme, uri.getRawUserInfo, uri.getHost, uri.getPort, uri.getPath,
- uri.getQuery, uri.getFragment))
- }
-
- // Overridden so we can inject our own cached files statuses.
- override def getPartitions: Array[SparkPartition] = {
- val inputFormat = new ParquetInputFormat[InternalRow] {
- override def listStatus(jobContext: JobContext): JList[FileStatus] = {
- if (cacheMetadata) cachedStatuses.asJava else super.listStatus(jobContext)
- }
- }
-
- val jobContext = new JobContextImpl(getConf(isDriverSide = true), jobId)
- val rawSplits = inputFormat.getSplits(jobContext)
-
- Array.tabulate[SparkPartition](rawSplits.size) { i =>
- new SqlNewHadoopPartition(
- id, i, rawSplits.get(i).asInstanceOf[InputSplit with Writable])
- }
- }
- }
- }
- }
}
// NOTE: This class is instantiated and used on executor side only, no need to be serializable.
@@ -911,12 +789,4 @@ private[sql] object ParquetRelation extends Logging {
// should be removed after this issue is fixed.
}
}
-
- // The parquet compression short names
- val shortParquetCompressionCodecNames = Map(
- "none" -> CompressionCodecName.UNCOMPRESSED,
- "uncompressed" -> CompressionCodecName.UNCOMPRESSED,
- "snappy" -> CompressionCodecName.SNAPPY,
- "gzip" -> CompressionCodecName.GZIP,
- "lzo" -> CompressionCodecName.LZO)
}
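The new supportBatch check above gates the vectorized, batch-returning path on four conditions. Restated as a standalone predicate with simplified stand-ins for SQLConf and StructType (Field and the parameter names are hypothetical):

case class Field(name: String, isAtomic: Boolean)

def supportBatch(vectorizedEnabled: Boolean,
                 wholeStageEnabled: Boolean,
                 maxFields: Int,
                 schema: Seq[Field]): Boolean = {
  vectorizedEnabled && wholeStageEnabled &&
    schema.length <= maxFields &&
    schema.forall(_.isAtomic)   // no nested or complex column types
}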
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala
index 5cfc9e9afa..94ecb7a286 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala
@@ -17,24 +17,20 @@
package org.apache.spark.sql.execution.datasources.text
+import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileStatus, Path}
-import org.apache.hadoop.io.{LongWritable, NullWritable, Text}
-import org.apache.hadoop.mapred.{JobConf, TextInputFormat}
+import org.apache.hadoop.io.{NullWritable, Text}
import org.apache.hadoop.mapreduce.{Job, RecordWriter, TaskAttemptContext}
-import org.apache.hadoop.mapreduce.lib.input.FileInputFormat
import org.apache.hadoop.mapreduce.lib.output.TextOutputFormat
-import org.apache.spark.broadcast.Broadcast
-import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{AnalysisException, Row, SQLContext}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.sql.catalyst.expressions.codegen.{BufferHolder, UnsafeRowWriter}
-import org.apache.spark.sql.execution.datasources.CompressionCodecs
+import org.apache.spark.sql.execution.datasources.{CompressionCodecs, HadoopFileLinesReader, PartitionedFile}
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types.{StringType, StructType}
import org.apache.spark.util.SerializableConfiguration
-import org.apache.spark.util.collection.BitSet
/**
* A data source for reading text files.
@@ -87,43 +83,30 @@ class DefaultSource extends FileFormat with DataSourceRegister {
}
}
- override def buildInternalScan(
+ override def buildReader(
sqlContext: SQLContext,
dataSchema: StructType,
- requiredColumns: Array[String],
- filters: Array[Filter],
- bucketSet: Option[BitSet],
- inputFiles: Seq[FileStatus],
- broadcastedConf: Broadcast[SerializableConfiguration],
- options: Map[String, String]): RDD[InternalRow] = {
- verifySchema(dataSchema)
-
- val job = Job.getInstance(sqlContext.sparkContext.hadoopConfiguration)
- val conf = job.getConfiguration
- val paths = inputFiles
- .filterNot(_.getPath.getName startsWith "_")
- .map(_.getPath)
- .sortBy(_.toUri)
-
- if (paths.nonEmpty) {
- FileInputFormat.setInputPaths(job, paths: _*)
+ partitionSchema: StructType,
+ requiredSchema: StructType,
+ filters: Seq[Filter],
+ options: Map[String, String]): PartitionedFile => Iterator[InternalRow] = {
+ val conf = new Configuration(sqlContext.sparkContext.hadoopConfiguration)
+ val broadcastedConf =
+ sqlContext.sparkContext.broadcast(new SerializableConfiguration(conf))
+
+ file => {
+ val unsafeRow = new UnsafeRow(1)
+ val bufferHolder = new BufferHolder(unsafeRow)
+ val unsafeRowWriter = new UnsafeRowWriter(bufferHolder, 1)
+
+ new HadoopFileLinesReader(file, broadcastedConf.value.value).map { line =>
+ // Writes to an UnsafeRow directly
+ bufferHolder.reset()
+ unsafeRowWriter.write(0, line.getBytes, 0, line.getLength)
+ unsafeRow.setTotalSize(bufferHolder.totalSize())
+ unsafeRow
+ }
}
-
- sqlContext.sparkContext.hadoopRDD(
- conf.asInstanceOf[JobConf], classOf[TextInputFormat], classOf[LongWritable], classOf[Text])
- .mapPartitions { iter =>
- val unsafeRow = new UnsafeRow(1)
- val bufferHolder = new BufferHolder(unsafeRow)
- val unsafeRowWriter = new UnsafeRowWriter(bufferHolder, 1)
-
- iter.map { case (_, line) =>
- // Writes to an UnsafeRow directly
- bufferHolder.reset()
- unsafeRowWriter.write(0, line.getBytes, 0, line.getLength)
- unsafeRow.setTotalSize(bufferHolder.totalSize())
- unsafeRow
- }
- }
}
}
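The rewritten text reader allocates one UnsafeRow/BufferHolder/UnsafeRowWriter trio per file and resets it for every line instead of building a fresh row object each time. A Spark-free sketch of that buffer-reuse pattern (ByteArrayOutputStream stands in for the row buffer; the function name is made up):

import java.io.ByteArrayOutputStream
import java.nio.charset.StandardCharsets

def linesToRows(lines: Iterator[String]): Iterator[Array[Byte]] = {
  val buffer = new ByteArrayOutputStream()        // allocated once per file
  lines.map { line =>
    buffer.reset()                                // reuse the buffer instead of reallocating
    buffer.write(line.getBytes(StandardCharsets.UTF_8))
    buffer.toByteArray                            // the single-column "row" payload
  }
}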
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
index 5e573b3159..17eae88b49 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
@@ -25,7 +25,7 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Attribute
-import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodeFormatter, CodegenContext, ExprCode}
import org.apache.spark.sql.catalyst.trees.TreeNodeRef
import org.apache.spark.sql.internal.SQLConf
@@ -35,12 +35,38 @@ import org.apache.spark.sql.internal.SQLConf
* Usage:
* {{{
* import org.apache.spark.sql.execution.debug._
- * sql("SELECT key FROM src").debug()
- * dataFrame.typeCheck()
+ * sql("SELECT 1").debug()
+ * sql("SELECT 1").debugCodegen()
* }}}
*/
package object debug {
+ /** Helper function to evade the println() linter. */
+ private def debugPrint(msg: String): Unit = {
+ // scalastyle:off println
+ println(msg)
+ // scalastyle:on println
+ }
+
+ def codegenString(plan: SparkPlan): String = {
+ val codegenSubtrees = new collection.mutable.HashSet[WholeStageCodegen]()
+ plan transform {
+ case s: WholeStageCodegen =>
+ codegenSubtrees += s
+ s
+ case s => s
+ }
+ var output = s"Found ${codegenSubtrees.size} WholeStageCodegen subtrees.\n"
+ for ((s, i) <- codegenSubtrees.toSeq.zipWithIndex) {
+ output += s"== Subtree ${i + 1} / ${codegenSubtrees.size} ==\n"
+ output += s
+ output += "\nGenerated code:\n"
+ val (_, source) = s.doCodeGen()
+ output += s"${CodeFormatter.format(source)}\n"
+ }
+ output
+ }
+
/**
* Augments [[SQLContext]] with debug methods.
*/
@@ -51,9 +77,9 @@ package object debug {
}
/**
- * Augments [[DataFrame]]s with debug methods.
+ * Augments [[Dataset]]s with debug methods.
*/
- implicit class DebugQuery(query: DataFrame) extends Logging {
+ implicit class DebugQuery(query: Dataset[_]) extends Logging {
def debug(): Unit = {
val plan = query.queryExecution.executedPlan
val visited = new collection.mutable.HashSet[TreeNodeRef]()
@@ -62,12 +88,20 @@ package object debug {
visited += new TreeNodeRef(s)
DebugNode(s)
}
- logDebug(s"Results returned: ${debugPlan.execute().count()}")
+ debugPrint(s"Results returned: ${debugPlan.execute().count()}")
debugPlan.foreach {
case d: DebugNode => d.dumpStats()
case _ =>
}
}
+
+ /**
+ * Prints to stdout all the generated code found in this plan (i.e. the output of each
+ * WholeStageCodegen subtree).
+ */
+ def debugCodegen(): Unit = {
+ debugPrint(codegenString(query.queryExecution.executedPlan))
+ }
}
private[sql] case class DebugNode(child: SparkPlan) extends UnaryNode with CodegenSupport {
@@ -87,6 +121,7 @@ package object debug {
/**
* A collection of metrics for each column of output.
+ *
* @param elementTypes the actual runtime types for the output. Useful when there are bugs
* causing the wrong data to be projected.
*/
@@ -99,11 +134,11 @@ package object debug {
val columnStats: Array[ColumnMetrics] = Array.fill(child.output.size)(new ColumnMetrics())
def dumpStats(): Unit = {
- logDebug(s"== ${child.simpleString} ==")
- logDebug(s"Tuples output: ${tupleCount.value}")
+ debugPrint(s"== ${child.simpleString} ==")
+ debugPrint(s"Tuples output: ${tupleCount.value}")
child.output.zip(columnStats).foreach { case (attr, metric) =>
val actualDataTypes = metric.elementTypes.value.mkString("{", ",", "}")
- logDebug(s" ${attr.name} ${attr.dataType}: $actualDataTypes")
+ debugPrint(s" ${attr.name} ${attr.dataType}: $actualDataTypes")
}
}
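Usage of the debugging helpers added above, assuming an existing sqlContext; both methods print to stdout via debugPrint:

import org.apache.spark.sql.execution.debug._

val df = sqlContext.sql("SELECT 1")
df.debug()          // per-operator tuple counts and observed column types
df.debugCodegen()   // generated code of every WholeStageCodegen subtree in the plan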
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
index f5b083c216..a8f854136c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
@@ -27,12 +27,12 @@ import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.physical.{BroadcastDistribution, Distribution, Partitioning, UnspecifiedDistribution}
import org.apache.spark.sql.execution.{BinaryNode, CodegenSupport, SparkPlan}
import org.apache.spark.sql.execution.metric.SQLMetrics
-import org.apache.spark.util.collection.CompactBuffer
+import org.apache.spark.sql.types.LongType
/**
* Performs an inner hash join of two child relations. When the output RDD of this operator is
* being constructed, a Spark job is asynchronously started to calculate the values for the
- * broadcasted relation. This data is then placed in a Spark broadcast variable. The streamed
+ * broadcast relation. This data is then placed in a Spark broadcast variable. The streamed
* relation is not shuffled.
*/
case class BroadcastHashJoin(
@@ -51,10 +51,7 @@ case class BroadcastHashJoin(
override def outputPartitioning: Partitioning = streamedPlan.outputPartitioning
override def requiredChildDistribution: Seq[Distribution] = {
- val mode = HashedRelationBroadcastMode(
- canJoinKeyFitWithinLong,
- rewriteKeyExpr(buildKeys),
- buildPlan.output)
+ val mode = HashedRelationBroadcastMode(buildKeys)
buildSide match {
case BuildLeft =>
BroadcastDistribution(mode) :: UnspecifiedDistribution :: Nil
@@ -68,37 +65,9 @@ case class BroadcastHashJoin(
val broadcastRelation = buildPlan.executeBroadcast[HashedRelation]()
streamedPlan.execute().mapPartitions { streamedIter =>
- val joinedRow = new JoinedRow()
- val hashTable = broadcastRelation.value
- TaskContext.get().taskMetrics().incPeakExecutionMemory(hashTable.getMemorySize)
- val keyGenerator = streamSideKeyGenerator
- val resultProj = createResultProjection
-
- joinType match {
- case Inner =>
- hashJoin(streamedIter, hashTable, numOutputRows)
-
- case LeftOuter =>
- streamedIter.flatMap { currentRow =>
- val rowKey = keyGenerator(currentRow)
- joinedRow.withLeft(currentRow)
- leftOuterIterator(rowKey, joinedRow, hashTable.get(rowKey), resultProj, numOutputRows)
- }
-
- case RightOuter =>
- streamedIter.flatMap { currentRow =>
- val rowKey = keyGenerator(currentRow)
- joinedRow.withRight(currentRow)
- rightOuterIterator(rowKey, hashTable.get(rowKey), joinedRow, resultProj, numOutputRows)
- }
-
- case LeftSemi =>
- hashSemiJoin(streamedIter, hashTable, numOutputRows)
-
- case x =>
- throw new IllegalArgumentException(
- s"BroadcastHashJoin should not take $x as the JoinType")
- }
+ val hashed = broadcastRelation.value.asReadOnlyCopy()
+ TaskContext.get().taskMetrics().incPeakExecutionMemory(hashed.estimatedSize)
+ join(streamedIter, hashed, numOutputRows)
}
}
@@ -115,6 +84,7 @@ case class BroadcastHashJoin(
case Inner => codegenInner(ctx, input)
case LeftOuter | RightOuter => codegenOuter(ctx, input)
case LeftSemi => codegenSemi(ctx, input)
+ case LeftAnti => codegenAnti(ctx, input)
case x =>
throw new IllegalArgumentException(
s"BroadcastHashJoin should not take $x as the JoinType")
@@ -122,8 +92,8 @@ case class BroadcastHashJoin(
}
/**
- * Returns a tuple of Broadcast of HashedRelation and the variable name for it.
- */
+ * Returns a tuple of Broadcast of HashedRelation and the variable name for it.
+ */
private def prepareBroadcast(ctx: CodegenContext): (Broadcast[HashedRelation], String) = {
// create a name for HashedRelation
val broadcastRelation = buildPlan.executeBroadcast[HashedRelation]()
@@ -132,36 +102,34 @@ case class BroadcastHashJoin(
val clsName = broadcastRelation.value.getClass.getName
ctx.addMutableState(clsName, relationTerm,
s"""
- | $relationTerm = ($clsName) $broadcast.value();
- | incPeakExecutionMemory($relationTerm.getMemorySize());
+ | $relationTerm = (($clsName) $broadcast.value()).asReadOnlyCopy();
+ | incPeakExecutionMemory($relationTerm.estimatedSize());
""".stripMargin)
(broadcastRelation, relationTerm)
}
/**
- * Returns the code for generating join key for stream side, and expression of whether the key
- * has any null in it or not.
- */
+ * Returns the code for generating join key for stream side, and expression of whether the key
+ * has any null in it or not.
+ */
private def genStreamSideJoinKey(
ctx: CodegenContext,
input: Seq[ExprCode]): (ExprCode, String) = {
ctx.currentVars = input
- if (canJoinKeyFitWithinLong) {
+ if (streamedKeys.length == 1 && streamedKeys.head.dataType == LongType) {
// generate the join key as Long
- val expr = rewriteKeyExpr(streamedKeys).head
- val ev = BindReferences.bindReference(expr, streamedPlan.output).gen(ctx)
+ val ev = streamedKeys.head.gen(ctx)
(ev, ev.isNull)
} else {
// generate the join key as UnsafeRow
- val keyExpr = streamedKeys.map(BindReferences.bindReference(_, streamedPlan.output))
- val ev = GenerateUnsafeProjection.createCode(ctx, keyExpr)
+ val ev = GenerateUnsafeProjection.createCode(ctx, streamedKeys)
(ev, s"${ev.value}.anyNull()")
}
}
/**
- * Generates the code for variable of build side.
- */
+ * Generates the code for variable of build side.
+ */
private def genBuildSideVars(ctx: CodegenContext, matched: String): Seq[ExprCode] = {
ctx.currentVars = null
ctx.INPUT_ROW = matched
@@ -188,15 +156,14 @@ case class BroadcastHashJoin(
}
/**
- * Generates the code for Inner join.
- */
- private def codegenInner(ctx: CodegenContext, input: Seq[ExprCode]): String = {
- val (broadcastRelation, relationTerm) = prepareBroadcast(ctx)
- val (keyEv, anyNull) = genStreamSideJoinKey(ctx, input)
+ * Generate the (non-equi) condition used to filter joined rows. This is used in Inner, Left Semi
+ * and Left Anti joins.
+ */
+ private def getJoinCondition(
+ ctx: CodegenContext,
+ input: Seq[ExprCode]): (String, String, Seq[ExprCode]) = {
val matched = ctx.freshName("matched")
val buildVars = genBuildSideVars(ctx, matched)
- val numOutput = metricTerm(ctx, "numOutputRows")
-
val checkCondition = if (condition.isDefined) {
val expr = condition.get
// evaluate the variables from build side that used by condition
@@ -212,12 +179,23 @@ case class BroadcastHashJoin(
} else {
""
}
+ (matched, checkCondition, buildVars)
+ }
+
+ /**
+ * Generates the code for Inner join.
+ */
+ private def codegenInner(ctx: CodegenContext, input: Seq[ExprCode]): String = {
+ val (broadcastRelation, relationTerm) = prepareBroadcast(ctx)
+ val (keyEv, anyNull) = genStreamSideJoinKey(ctx, input)
+ val (matched, checkCondition, buildVars) = getJoinCondition(ctx, input)
+ val numOutput = metricTerm(ctx, "numOutputRows")
val resultVars = buildSide match {
case BuildLeft => buildVars ++ input
case BuildRight => input ++ buildVars
}
- if (broadcastRelation.value.isInstanceOf[UniqueHashedRelation]) {
+ if (broadcastRelation.value.keyIsUnique) {
s"""
|// generate join key for stream side
|${keyEv.code}
@@ -232,18 +210,15 @@ case class BroadcastHashJoin(
} else {
ctx.copyResult = true
val matches = ctx.freshName("matches")
- val bufferType = classOf[CompactBuffer[UnsafeRow]].getName
- val i = ctx.freshName("i")
- val size = ctx.freshName("size")
+ val iteratorCls = classOf[Iterator[UnsafeRow]].getName
s"""
|// generate join key for stream side
|${keyEv.code}
|// find matches from HashRelation
- |$bufferType $matches = $anyNull ? null : ($bufferType)$relationTerm.get(${keyEv.value});
+ |$iteratorCls $matches = $anyNull ? null : ($iteratorCls)$relationTerm.get(${keyEv.value});
|if ($matches == null) continue;
- |int $size = $matches.size();
- |for (int $i = 0; $i < $size; $i++) {
- | UnsafeRow $matched = (UnsafeRow) $matches.apply($i);
+ |while ($matches.hasNext()) {
+ | UnsafeRow $matched = (UnsafeRow) $matches.next();
| $checkCondition
| $numOutput.add(1);
| ${consume(ctx, resultVars)}
@@ -252,10 +227,9 @@ case class BroadcastHashJoin(
}
}
-
/**
- * Generates the code for left or right outer join.
- */
+ * Generates the code for left or right outer join.
+ */
private def codegenOuter(ctx: CodegenContext, input: Seq[ExprCode]): String = {
val (broadcastRelation, relationTerm) = prepareBroadcast(ctx)
val (keyEv, anyNull) = genStreamSideJoinKey(ctx, input)
@@ -287,7 +261,7 @@ case class BroadcastHashJoin(
case BuildLeft => buildVars ++ input
case BuildRight => input ++ buildVars
}
- if (broadcastRelation.value.isInstanceOf[UniqueHashedRelation]) {
+ if (broadcastRelation.value.keyIsUnique) {
s"""
|// generate join key for stream side
|${keyEv.code}
@@ -306,22 +280,20 @@ case class BroadcastHashJoin(
} else {
ctx.copyResult = true
val matches = ctx.freshName("matches")
- val bufferType = classOf[CompactBuffer[UnsafeRow]].getName
- val i = ctx.freshName("i")
- val size = ctx.freshName("size")
+ val iteratorCls = classOf[Iterator[UnsafeRow]].getName
val found = ctx.freshName("found")
s"""
|// generate join key for stream side
|${keyEv.code}
|// find matches from HashRelation
- |$bufferType $matches = $anyNull ? null : ($bufferType)$relationTerm.get(${keyEv.value});
- |int $size = $matches != null ? $matches.size() : 0;
+ |$iteratorCls $matches = $anyNull ? null : ($iteratorCls)$relationTerm.get(${keyEv.value});
|boolean $found = false;
|// the last iteration of this loop is to emit an empty row if there is no matched rows.
- |for (int $i = 0; $i <= $size; $i++) {
- | UnsafeRow $matched = $i < $size ? (UnsafeRow) $matches.apply($i) : null;
+ |while ($matches != null && $matches.hasNext() || !$found) {
+ | UnsafeRow $matched = $matches != null && $matches.hasNext() ?
+ | (UnsafeRow) $matches.next() : null;
| ${checkCondition.trim}
- | if (!$conditionPassed || ($i == $size && $found)) continue;
+ | if (!$conditionPassed) continue;
| $found = true;
| $numOutput.add(1);
| ${consume(ctx, resultVars)}
@@ -336,27 +308,9 @@ case class BroadcastHashJoin(
private def codegenSemi(ctx: CodegenContext, input: Seq[ExprCode]): String = {
val (broadcastRelation, relationTerm) = prepareBroadcast(ctx)
val (keyEv, anyNull) = genStreamSideJoinKey(ctx, input)
- val matched = ctx.freshName("matched")
- val buildVars = genBuildSideVars(ctx, matched)
+ val (matched, checkCondition, _) = getJoinCondition(ctx, input)
val numOutput = metricTerm(ctx, "numOutputRows")
-
- val checkCondition = if (condition.isDefined) {
- val expr = condition.get
- // evaluate the variables from build side that used by condition
- val eval = evaluateRequiredVariables(buildPlan.output, buildVars, expr.references)
- // filter the output via condition
- ctx.currentVars = input ++ buildVars
- val ev = BindReferences.bindReference(expr, streamedPlan.output ++ buildPlan.output).gen(ctx)
- s"""
- |$eval
- |${ev.code}
- |if (${ev.isNull} || !${ev.value}) continue;
- """.stripMargin
- } else {
- ""
- }
-
- if (broadcastRelation.value.isInstanceOf[UniqueHashedRelation]) {
+ if (broadcastRelation.value.keyIsUnique) {
s"""
|// generate join key for stream side
|${keyEv.code}
@@ -369,23 +323,19 @@ case class BroadcastHashJoin(
""".stripMargin
} else {
val matches = ctx.freshName("matches")
- val bufferType = classOf[CompactBuffer[UnsafeRow]].getName
- val i = ctx.freshName("i")
- val size = ctx.freshName("size")
+ val iteratorCls = classOf[Iterator[UnsafeRow]].getName
val found = ctx.freshName("found")
s"""
|// generate join key for stream side
|${keyEv.code}
|// find matches from HashRelation
- |$bufferType $matches = $anyNull ? null : ($bufferType)$relationTerm.get(${keyEv.value});
+ |$iteratorCls $matches = $anyNull ? null : ($iteratorCls)$relationTerm.get(${keyEv.value});
|if ($matches == null) continue;
- |int $size = $matches.size();
|boolean $found = false;
- |for (int $i = 0; $i < $size; $i++) {
- | UnsafeRow $matched = (UnsafeRow) $matches.apply($i);
+ |while (!$found && $matches.hasNext()) {
+ | UnsafeRow $matched = (UnsafeRow) $matches.next();
| $checkCondition
| $found = true;
- | break;
|}
|if (!$found) continue;
|$numOutput.add(1);
@@ -393,4 +343,57 @@ case class BroadcastHashJoin(
""".stripMargin
}
}
+
+ /**
+ * Generates the code for anti join.
+ */
+ private def codegenAnti(ctx: CodegenContext, input: Seq[ExprCode]): String = {
+ val (broadcastRelation, relationTerm) = prepareBroadcast(ctx)
+ val (keyEv, anyNull) = genStreamSideJoinKey(ctx, input)
+ val (matched, checkCondition, _) = getJoinCondition(ctx, input)
+ val numOutput = metricTerm(ctx, "numOutputRows")
+
+ if (broadcastRelation.value.keyIsUnique) {
+ s"""
+ |// generate join key for stream side
+ |${keyEv.code}
+ |// Check if the key has nulls.
+ |if (!($anyNull)) {
+ | // Check if the HashedRelation exists.
+ | UnsafeRow $matched = (UnsafeRow)$relationTerm.getValue(${keyEv.value});
+ | if ($matched != null) {
+ | // Evaluate the condition.
+ | $checkCondition
+ | }
+ |}
+ |$numOutput.add(1);
+ |${consume(ctx, input)}
+ """.stripMargin
+ } else {
+ val matches = ctx.freshName("matches")
+ val iteratorCls = classOf[Iterator[UnsafeRow]].getName
+ val found = ctx.freshName("found")
+ s"""
+ |// generate join key for stream side
+ |${keyEv.code}
+ |// Check if the key has nulls.
+ |if (!($anyNull)) {
+ | // Check if the HashedRelation exists.
+ | $iteratorCls $matches = ($iteratorCls)$relationTerm.get(${keyEv.value});
+ | if ($matches != null) {
+ | // Evaluate the condition.
+ | boolean $found = false;
+ | while (!$found && $matches.hasNext()) {
+ | UnsafeRow $matched = (UnsafeRow) $matches.next();
+ | $checkCondition
+ | $found = true;
+ | }
+ | if ($found) continue;
+ | }
+ |}
+ |$numOutput.add(1);
+ |${consume(ctx, input)}
+ """.stripMargin
+ }
+ }
}
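
The generated anti-join loop above keeps a streamed row only when its join key is null, the key is absent from the hashed relation, or no matching build-side row passes the extra condition. A minimal sketch of that filtering rule over plain Scala collections (toy types and names, not Spark APIs; Option models a nullable key and a missing-map entry models a null lookup result):

object AntiJoinSketch {
  // Minimal sketch (not Spark code) of the left-anti filtering rule.
  def leftAnti[K, S, B](
      streamed: Seq[S],
      build: Map[K, Seq[B]],
      key: S => Option[K],
      cond: (S, B) => Boolean): Seq[S] =
    streamed.filter { s =>
      key(s) match {
        case None    => true                                        // null keys never match, so the row survives
        case Some(k) => !build.getOrElse(k, Nil).exists(b => cond(s, b))
      }
    }

  def main(args: Array[String]): Unit = {
    val build = Map(1 -> Seq("x"), 2 -> Seq("y"))
    println(leftAnti[Int, Int, String](Seq(1, 2, 3), build, k => Some(k), (_, _) => true))
    // List(3): rows 1 and 2 have matches and are dropped, row 3 survives
  }
}
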
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala
index 4143e944e5..4ba710c10a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala
@@ -73,7 +73,7 @@ case class BroadcastNestedLoopJoin(
left.output.map(_.withNullability(true)) ++ right.output
case FullOuter =>
left.output.map(_.withNullability(true)) ++ right.output.map(_.withNullability(true))
- case LeftSemi =>
+ case LeftExistence(_) =>
left.output
case x =>
throw new IllegalArgumentException(
@@ -175,8 +175,11 @@ case class BroadcastNestedLoopJoin(
* The implementation for these joins:
*
* LeftSemi with BuildRight
+ * Anti with BuildRight
*/
- private def leftSemiJoin(relation: Broadcast[Array[InternalRow]]): RDD[InternalRow] = {
+ private def leftExistenceJoin(
+ relation: Broadcast[Array[InternalRow]],
+ exists: Boolean): RDD[InternalRow] = {
assert(buildSide == BuildRight)
streamed.execute().mapPartitionsInternal { streamedIter =>
val buildRows = relation.value
@@ -184,10 +187,12 @@ case class BroadcastNestedLoopJoin(
if (condition.isDefined) {
streamedIter.filter(l =>
- buildRows.exists(r => boundCondition(joinedRow(l, r)))
+ buildRows.exists(r => boundCondition(joinedRow(l, r))) == exists
)
+ } else if (buildRows.nonEmpty == exists) {
+ streamedIter
} else {
- streamedIter.filter(r => !buildRows.isEmpty)
+ Iterator.empty
}
}
}
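
leftExistenceJoin folds LeftSemi and LeftAnti (with BuildRight) into one code path: a streamed row is kept when a matching broadcast row exists (exists = true) or when none does (exists = false), and with no extra condition the whole partition is either passed through or dropped depending on whether the broadcast side is empty. A hedged sketch of the same shape with ordinary collections (names are illustrative only):

object LeftExistenceSketch {
  // Sketch only (toy types, not Spark APIs): semi and anti unified by one boolean.
  def leftExistence[S, B](
      streamed: Iterator[S],
      buildRows: Array[B],
      cond: Option[(S, B) => Boolean],
      exists: Boolean): Iterator[S] = cond match {
    case Some(c) =>
      streamed.filter(s => buildRows.exists(b => c(s, b)) == exists)
    case None if buildRows.nonEmpty == exists =>
      streamed            // with no extra condition, every streamed row passes
    case None =>
      Iterator.empty      // ... or none does
  }

  def main(args: Array[String]): Unit = {
    println(leftExistence(Iterator(1, 2, 3), Array("x"), None, exists = true).toList)  // List(1, 2, 3)
    println(leftExistence(Iterator(1, 2, 3), Array("x"), None, exists = false).toList) // List()
  }
}
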
@@ -199,6 +204,7 @@ case class BroadcastNestedLoopJoin(
* RightOuter with BuildRight
* FullOuter
* LeftSemi with BuildLeft
+ * Anti with BuildLeft
*/
private def defaultJoin(relation: Broadcast[Array[InternalRow]]): RDD[InternalRow] = {
/** All rows that either match both-way, or rows from streamed joined with nulls. */
@@ -236,7 +242,27 @@ case class BroadcastNestedLoopJoin(
}
i += 1
}
- return sparkContext.makeRDD(buf.toSeq)
+ return sparkContext.makeRDD(buf)
+ }
+
+ val notMatchedBroadcastRows: Seq[InternalRow] = {
+ val nulls = new GenericMutableRow(streamed.output.size)
+ val buf: CompactBuffer[InternalRow] = new CompactBuffer()
+ var i = 0
+ val buildRows = relation.value
+ val joinedRow = new JoinedRow
+ joinedRow.withLeft(nulls)
+ while (i < buildRows.length) {
+ if (!matchedBroadcastRows.get(i)) {
+ buf += joinedRow.withRight(buildRows(i)).copy()
+ }
+ i += 1
+ }
+ buf
+ }
+
+ if (joinType == LeftAnti) {
+ return sparkContext.makeRDD(notMatchedBroadcastRows)
}
val matchedStreamRows = streamRdd.mapPartitionsInternal { streamedIter =>
@@ -264,22 +290,6 @@ case class BroadcastNestedLoopJoin(
}
}
- val notMatchedBroadcastRows: Seq[InternalRow] = {
- val nulls = new GenericMutableRow(streamed.output.size)
- val buf: CompactBuffer[InternalRow] = new CompactBuffer()
- var i = 0
- val buildRows = relation.value
- val joinedRow = new JoinedRow
- joinedRow.withLeft(nulls)
- while (i < buildRows.length) {
- if (!matchedBroadcastRows.get(i)) {
- buf += joinedRow.withRight(buildRows(i)).copy()
- }
- i += 1
- }
- buf.toSeq
- }
-
sparkContext.union(
matchedStreamRows,
sparkContext.makeRDD(notMatchedBroadcastRows)
@@ -295,13 +305,16 @@ case class BroadcastNestedLoopJoin(
case (LeftOuter, BuildRight) | (RightOuter, BuildLeft) =>
outerJoin(broadcastedRelation)
case (LeftSemi, BuildRight) =>
- leftSemiJoin(broadcastedRelation)
+ leftExistenceJoin(broadcastedRelation, exists = true)
+ case (LeftAnti, BuildRight) =>
+ leftExistenceJoin(broadcastedRelation, exists = false)
case _ =>
/**
* LeftOuter with BuildLeft
* RightOuter with BuildRight
* FullOuter
* LeftSemi with BuildLeft
+ * Anti with BuildLeft
*/
defaultJoin(broadcastedRelation)
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala
index fb65b50da8..edb4c5a16f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala
@@ -28,10 +28,10 @@ import org.apache.spark.util.CompletionIterator
import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter
/**
- * An optimized CartesianRDD for UnsafeRow, which will cache the rows from second child RDD,
- * will be much faster than building the right partition for every row in left RDD, it also
- * materialize the right RDD (in case of the right RDD is nondeterministic).
- */
+ * An optimized CartesianRDD for UnsafeRow, which caches the rows from the second child RDD.
+ * This is much faster than rebuilding the right partition for every row in the left RDD, and it
+ * also materializes the right RDD (in case the right RDD is nondeterministic).
+ */
private[spark]
class UnsafeCartesianRDD(left : RDD[UnsafeRow], right : RDD[UnsafeRow], numFieldsOfRight: Int)
extends CartesianRDD[UnsafeRow, UnsafeRow](left.sparkContext, left, right) {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
index 5f42d07273..d6feedc272 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
@@ -17,15 +17,12 @@
package org.apache.spark.sql.execution.joins
-import java.util.NoSuchElementException
-
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans._
-import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.execution.{RowIterator, SparkPlan}
import org.apache.spark.sql.execution.metric.LongSQLMetric
-import org.apache.spark.sql.types.{IntegerType, IntegralType, LongType}
-import org.apache.spark.util.collection.CompactBuffer
+import org.apache.spark.sql.types.{IntegralType, LongType}
trait HashJoin {
self: SparkPlan =>
@@ -46,7 +43,7 @@ trait HashJoin {
left.output ++ right.output.map(_.withNullability(true))
case RightOuter =>
left.output.map(_.withNullability(true)) ++ right.output
- case LeftSemi =>
+ case LeftExistence(_) =>
left.output
case x =>
throw new IllegalArgumentException(s"HashJoin should not take $x as the JoinType")
@@ -58,17 +55,23 @@ trait HashJoin {
case BuildRight => (right, left)
}
- protected lazy val (buildKeys, streamedKeys) = buildSide match {
- case BuildLeft => (leftKeys, rightKeys)
- case BuildRight => (rightKeys, leftKeys)
+ protected lazy val (buildKeys, streamedKeys) = {
+ require(leftKeys.map(_.dataType) == rightKeys.map(_.dataType),
+ "Join keys from two sides should have same types")
+ val lkeys = rewriteKeyExpr(leftKeys).map(BindReferences.bindReference(_, left.output))
+ val rkeys = rewriteKeyExpr(rightKeys).map(BindReferences.bindReference(_, right.output))
+ buildSide match {
+ case BuildLeft => (lkeys, rkeys)
+ case BuildRight => (rkeys, lkeys)
+ }
}
/**
- * Try to rewrite the key as LongType so we can use getLong(), if they key can fit with a long.
- *
- * If not, returns the original expressions.
- */
- def rewriteKeyExpr(keys: Seq[Expression]): Seq[Expression] = {
+ * Try to rewrite the key as LongType so we can use getLong(), if the key can fit within a long.
+ *
+ * If not, returns the original expressions.
+ */
+ private def rewriteKeyExpr(keys: Seq[Expression]): Seq[Expression] = {
var keyExpr: Expression = null
var width = 0
keys.foreach { e =>
@@ -83,17 +86,8 @@ trait HashJoin {
width = dt.defaultSize
} else {
val bits = dt.defaultSize * 8
- // hashCode of Long is (l >> 32) ^ l.toInt, it means the hash code of an long with same
- // value in high 32 bit and low 32 bit will be 0. To avoid the worst case that keys
- // with two same ints have hash code 0, we rotate the bits of second one.
- val rotated = if (e.dataType == IntegerType) {
- // (e >>> 15) | (e << 17)
- BitwiseOr(ShiftRightUnsigned(e, Literal(15)), ShiftLeft(e, Literal(17)))
- } else {
- e
- }
keyExpr = BitwiseOr(ShiftLeft(keyExpr, Literal(bits)),
- BitwiseAnd(Cast(rotated, LongType), Literal((1L << bits) - 1)))
+ BitwiseAnd(Cast(e, LongType), Literal((1L << bits) - 1)))
width -= bits
}
// TODO: support BooleanType, DateType and TimestampType
@@ -104,175 +98,129 @@ trait HashJoin {
keyExpr :: Nil
}
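
rewriteKeyExpr concatenates several fixed-width integral keys into one Long by shifting the accumulated key left by the new key's bit width and OR-ing in the new key masked to that width; this revision also drops the earlier int-rotation trick. A small REPL-style illustration of the packing for two int keys (the helper name is made up):

// Illustrative only: two 32-bit int keys packed into one long,
// following the (keyExpr << bits) | (key & mask) shape above.
def packTwoInts(a: Int, b: Int): Long = (a.toLong << 32) | (b.toLong & 0xffffffffL)

packTwoInts(3, 5)   // 0x0000000300000005L
packTwoInts(3, -1)  // 0x00000003ffffffffL; the mask keeps only the low 32 bits of b
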
- protected lazy val canJoinKeyFitWithinLong: Boolean = {
- val sameTypes = buildKeys.map(_.dataType) == streamedKeys.map(_.dataType)
- val key = rewriteKeyExpr(buildKeys)
- sameTypes && key.length == 1 && key.head.dataType.isInstanceOf[LongType]
- }
-
- protected def buildSideKeyGenerator: Projection =
- UnsafeProjection.create(rewriteKeyExpr(buildKeys), buildPlan.output)
+ protected def buildSideKeyGenerator(): Projection =
+ UnsafeProjection.create(buildKeys)
- protected def streamSideKeyGenerator: Projection =
- UnsafeProjection.create(rewriteKeyExpr(streamedKeys), streamedPlan.output)
+ protected def streamSideKeyGenerator(): UnsafeProjection =
+ UnsafeProjection.create(streamedKeys)
@transient private[this] lazy val boundCondition = if (condition.isDefined) {
- newPredicate(condition.getOrElse(Literal(true)), left.output ++ right.output)
+ newPredicate(condition.get, streamedPlan.output ++ buildPlan.output)
} else {
(r: InternalRow) => true
}
- protected def createResultProjection: (InternalRow) => InternalRow =
- UnsafeProjection.create(self.schema)
-
- protected def hashJoin(
- streamIter: Iterator[InternalRow],
- hashedRelation: HashedRelation,
- numOutputRows: LongSQLMetric): Iterator[InternalRow] = {
- new Iterator[InternalRow] {
- private[this] var currentStreamedRow: InternalRow = _
- private[this] var currentHashMatches: Seq[InternalRow] = _
- private[this] var currentMatchPosition: Int = -1
-
- // Mutable per row objects.
- private[this] val joinRow = new JoinedRow
- private[this] val resultProjection = createResultProjection
-
- private[this] val joinKeys = streamSideKeyGenerator
-
- override final def hasNext: Boolean = {
- while (true) {
- // check if it's end of current matches
- if (currentHashMatches != null && currentMatchPosition == currentHashMatches.length) {
- currentHashMatches = null
- currentMatchPosition = -1
- }
-
- // find the next match
- while (currentHashMatches == null && streamIter.hasNext) {
- currentStreamedRow = streamIter.next()
- val key = joinKeys(currentStreamedRow)
- if (!key.anyNull) {
- currentHashMatches = hashedRelation.get(key)
- if (currentHashMatches != null) {
- currentMatchPosition = 0
- }
- }
- }
- if (currentHashMatches == null) {
- return false
- }
-
- // found some matches
- buildSide match {
- case BuildRight => joinRow(currentStreamedRow, currentHashMatches(currentMatchPosition))
- case BuildLeft => joinRow(currentHashMatches(currentMatchPosition), currentStreamedRow)
- }
- if (boundCondition(joinRow)) {
- return true
- } else {
- currentMatchPosition += 1
- }
- }
- false // unreachable
- }
-
- override final def next(): InternalRow = {
- // next() could be called without calling hasNext()
- if (hasNext) {
- currentMatchPosition += 1
- numOutputRows += 1
- resultProjection(joinRow)
- } else {
- throw new NoSuchElementException
- }
- }
+ protected def createResultProjection(): (InternalRow) => InternalRow = {
+ if (joinType == LeftSemi) {
+ UnsafeProjection.create(output, output)
+ } else {
+ // Always put the stream side on the left to simplify the implementation;
+ // both the left and the right side could be null.
+ UnsafeProjection.create(
+ output, (streamedPlan.output ++ buildPlan.output).map(_.withNullability(true)))
}
}
- @transient protected[this] lazy val EMPTY_LIST = CompactBuffer[InternalRow]()
-
- @transient private[this] lazy val leftNullRow = new GenericInternalRow(left.output.length)
- @transient private[this] lazy val rightNullRow = new GenericInternalRow(right.output.length)
-
- protected[this] def leftOuterIterator(
- key: InternalRow,
- joinedRow: JoinedRow,
- rightIter: Iterable[InternalRow],
- resultProjection: InternalRow => InternalRow,
- numOutputRows: LongSQLMetric): Iterator[InternalRow] = {
- val ret: Iterable[InternalRow] = {
- if (!key.anyNull) {
- val temp = if (rightIter != null) {
- rightIter.collect {
- case r if boundCondition(joinedRow.withRight(r)) => {
- numOutputRows += 1
- resultProjection(joinedRow).copy()
- }
- }
- } else {
- List.empty
- }
- if (temp.isEmpty) {
- numOutputRows += 1
- resultProjection(joinedRow.withRight(rightNullRow)) :: Nil
- } else {
- temp
- }
+ private def innerJoin(
+ streamIter: Iterator[InternalRow],
+ hashedRelation: HashedRelation): Iterator[InternalRow] = {
+ val joinRow = new JoinedRow
+ val joinKeys = streamSideKeyGenerator()
+ streamIter.flatMap { srow =>
+ joinRow.withLeft(srow)
+ val matches = hashedRelation.get(joinKeys(srow))
+ if (matches != null) {
+ matches.map(joinRow.withRight(_)).filter(boundCondition)
} else {
- numOutputRows += 1
- resultProjection(joinedRow.withRight(rightNullRow)) :: Nil
+ Seq.empty
}
}
- ret.iterator
}
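
The rewritten innerJoin streams over the probe side and, for each row, flatMaps over the iterator returned by the hashed relation, keeping only pairs that pass the bound condition. A toy analogue with plain collections (not Spark types):

object InnerHashJoinSketch {
  // Toy analogue of the flatMap-based inner hash join above.
  def innerHashJoin[K, S, B](
      streamed: Iterator[S],
      build: Map[K, Seq[B]],
      key: S => K,
      cond: (S, B) => Boolean): Iterator[(S, B)] =
    streamed.flatMap { s =>
      build.getOrElse(key(s), Seq.empty).iterator.filter(b => cond(s, b)).map(b => (s, b))
    }

  def main(args: Array[String]): Unit = {
    val build = Map(1 -> Seq("a"), 2 -> Seq("b", "c"))
    val out = innerHashJoin[Int, Int, String](Iterator(1, 2, 3), build, identity, (_, _) => true)
    println(out.toList)   // List((1,a), (2,b), (2,c)); key 3 produces nothing
  }
}
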
- protected[this] def rightOuterIterator(
- key: InternalRow,
- leftIter: Iterable[InternalRow],
- joinedRow: JoinedRow,
- resultProjection: InternalRow => InternalRow,
- numOutputRows: LongSQLMetric): Iterator[InternalRow] = {
- val ret: Iterable[InternalRow] = {
- if (!key.anyNull) {
- val temp = if (leftIter != null) {
- leftIter.collect {
- case l if boundCondition(joinedRow.withLeft(l)) => {
- numOutputRows += 1
- resultProjection(joinedRow).copy()
+ private def outerJoin(
+ streamedIter: Iterator[InternalRow],
+ hashedRelation: HashedRelation): Iterator[InternalRow] = {
+ val joinedRow = new JoinedRow()
+ val keyGenerator = streamSideKeyGenerator()
+ val nullRow = new GenericInternalRow(buildPlan.output.length)
+
+ streamedIter.flatMap { currentRow =>
+ val rowKey = keyGenerator(currentRow)
+ joinedRow.withLeft(currentRow)
+ val buildIter = hashedRelation.get(rowKey)
+ new RowIterator {
+ private var found = false
+ override def advanceNext(): Boolean = {
+ while (buildIter != null && buildIter.hasNext) {
+ val nextBuildRow = buildIter.next()
+ if (boundCondition(joinedRow.withRight(nextBuildRow))) {
+ found = true
+ return true
}
}
- } else {
- List.empty
- }
- if (temp.isEmpty) {
- numOutputRows += 1
- resultProjection(joinedRow.withLeft(leftNullRow)) :: Nil
- } else {
- temp
+ if (!found) {
+ joinedRow.withRight(nullRow)
+ found = true
+ return true
+ }
+ false
}
- } else {
- numOutputRows += 1
- resultProjection(joinedRow.withLeft(leftNullRow)) :: Nil
- }
+ override def getRow: InternalRow = joinedRow
+ }.toScala
}
- ret.iterator
}
- protected def hashSemiJoin(
- streamIter: Iterator[InternalRow],
- hashedRelation: HashedRelation,
- numOutputRows: LongSQLMetric): Iterator[InternalRow] = {
- val joinKeys = streamSideKeyGenerator
+ private def semiJoin(
+ streamIter: Iterator[InternalRow],
+ hashedRelation: HashedRelation): Iterator[InternalRow] = {
+ val joinKeys = streamSideKeyGenerator()
val joinedRow = new JoinedRow
streamIter.filter { current =>
val key = joinKeys(current)
- lazy val rowBuffer = hashedRelation.get(key)
- val r = !key.anyNull && rowBuffer != null && (condition.isEmpty || rowBuffer.exists {
+ lazy val buildIter = hashedRelation.get(key)
+ !key.anyNull && buildIter != null && (condition.isEmpty || buildIter.exists {
(row: InternalRow) => boundCondition(joinedRow(current, row))
})
- if (r) numOutputRows += 1
- r
+ }
+ }
+
+ private def antiJoin(
+ streamIter: Iterator[InternalRow],
+ hashedRelation: HashedRelation): Iterator[InternalRow] = {
+ val joinKeys = streamSideKeyGenerator()
+ val joinedRow = new JoinedRow
+ streamIter.filter { current =>
+ val key = joinKeys(current)
+ lazy val buildIter = hashedRelation.get(key)
+ key.anyNull || buildIter == null || (condition.isDefined && !buildIter.exists {
+ row => boundCondition(joinedRow(current, row))
+ })
+ }
+ }
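
semiJoin and antiJoin are deliberately complementary filters over the same lookup: a row passes the semi filter exactly when it fails the anti filter, whether or not an extra condition is present. A small property-style sketch (toy types; Option models a null key and a null lookup result):

object ExistencePredicateSketch {
  // Sketch: matches == None models hashedRelation.get(key) returning null.
  def semiKeeps[B](keyIsNull: Boolean, matches: Option[Seq[B]], cond: Option[B => Boolean]): Boolean =
    !keyIsNull && matches.nonEmpty && cond.forall(c => matches.get.exists(c))

  def antiKeeps[B](keyIsNull: Boolean, matches: Option[Seq[B]], cond: Option[B => Boolean]): Boolean =
    keyIsNull || matches.isEmpty || cond.exists(c => !matches.get.exists(c))

  def main(args: Array[String]): Unit = {
    // For every sampled input, antiKeeps is the negation of semiKeeps, which is why the
    // codegen path can share getJoinCondition between LeftSemi and LeftAnti.
    val inputs = for {
      n <- Seq(true, false)
      m <- Seq(None, Some(Seq(1)), Some(Seq(2)))
      c <- Seq(None, Some((x: Int) => x == 1))
    } yield (n, m, c)
    assert(inputs.forall { case (n, m, c) => antiKeeps(n, m, c) == !semiKeeps(n, m, c) })
    println("anti == !semi for all sampled inputs")
  }
}
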
+
+ protected def join(
+ streamedIter: Iterator[InternalRow],
+ hashed: HashedRelation,
+ numOutputRows: LongSQLMetric): Iterator[InternalRow] = {
+
+ val joinedIter = joinType match {
+ case Inner =>
+ innerJoin(streamedIter, hashed)
+ case LeftOuter | RightOuter =>
+ outerJoin(streamedIter, hashed)
+ case LeftSemi =>
+ semiJoin(streamedIter, hashed)
+ case LeftAnti =>
+ antiJoin(streamedIter, hashed)
+ case x =>
+ throw new IllegalArgumentException(
+ s"BroadcastHashJoin should not take $x as the JoinType")
+ }
+
+ val resultProj = createResultProjection
+ joinedIter.map { r =>
+ numOutputRows += 1
+ resultProj(r)
}
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
index 8cc3528639..0427db4e3b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
@@ -18,277 +18,189 @@
package org.apache.spark.sql.execution.joins
import java.io.{Externalizable, IOException, ObjectInput, ObjectOutput}
-import java.nio.ByteOrder
-import java.util.{HashMap => JavaHashMap}
-import org.apache.spark.{SparkConf, SparkEnv}
-import org.apache.spark.memory.{StaticMemoryManager, TaskMemoryManager}
+import org.apache.spark.{SparkConf, SparkEnv, SparkException}
+import org.apache.spark.memory.{MemoryConsumer, MemoryMode, StaticMemoryManager, TaskMemoryManager}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.physical.BroadcastMode
-import org.apache.spark.sql.execution.SparkSqlSerializer
+import org.apache.spark.sql.types.LongType
import org.apache.spark.unsafe.Platform
import org.apache.spark.unsafe.map.BytesToBytesMap
-import org.apache.spark.util.{KnownSizeEstimation, SizeEstimator, Utils}
-import org.apache.spark.util.collection.CompactBuffer
+import org.apache.spark.util.{KnownSizeEstimation, Utils}
/**
* Interface for a hashed relation by some key. Use [[HashedRelation.apply]] to create a concrete
* object.
*/
-private[execution] sealed trait HashedRelation {
+private[execution] sealed trait HashedRelation extends KnownSizeEstimation {
/**
- * Returns matched rows.
- */
- def get(key: InternalRow): Seq[InternalRow]
+ * Returns matched rows.
+ *
+ * Returns null if there are no matched rows.
+ */
+ def get(key: InternalRow): Iterator[InternalRow]
/**
- * Returns matched rows for a key that has only one column with LongType.
- */
- def get(key: Long): Seq[InternalRow] = {
+ * Returns matched rows for a key that has only one column with LongType.
+ *
+ * Returns null if there are no matched rows.
+ */
+ def get(key: Long): Iterator[InternalRow] = {
throw new UnsupportedOperationException
}
/**
- * Returns the size of used memory.
- */
- def getMemorySize: Long = 1L // to make the test happy
-
- // This is a helper method to implement Externalizable, and is used by
- // GeneralHashedRelation and UniqueKeyHashedRelation
- protected def writeBytes(out: ObjectOutput, serialized: Array[Byte]): Unit = {
- out.writeInt(serialized.length) // Write the length of serialized bytes first
- out.write(serialized)
- }
-
- // This is a helper method to implement Externalizable, and is used by
- // GeneralHashedRelation and UniqueKeyHashedRelation
- protected def readBytes(in: ObjectInput): Array[Byte] = {
- val serializedSize = in.readInt() // Read the length of serialized bytes first
- val bytes = new Array[Byte](serializedSize)
- in.readFully(bytes)
- bytes
- }
-}
-
-/**
- * Interface for a hashed relation that have only one row per key.
- *
- * We should call getValue() for better performance.
- */
-private[execution] trait UniqueHashedRelation extends HashedRelation {
-
- /**
- * Returns the matched single row.
- */
+ * Returns the matched single row.
+ */
def getValue(key: InternalRow): InternalRow
/**
- * Returns the matched single row with key that have only one column of LongType.
- */
+ * Returns the matched single row for a key that has only one column of LongType.
+ */
def getValue(key: Long): InternalRow = {
throw new UnsupportedOperationException
}
- override def get(key: InternalRow): Seq[InternalRow] = {
- val row = getValue(key)
- if (row != null) {
- CompactBuffer[InternalRow](row)
- } else {
- null
- }
- }
+ /**
+ * Returns true iff all the keys are unique.
+ */
+ def keyIsUnique: Boolean
- override def get(key: Long): Seq[InternalRow] = {
- val row = getValue(key)
- if (row != null) {
- CompactBuffer[InternalRow](row)
- } else {
- null
- }
- }
+ /**
+ * Returns a read-only copy of this, to be safely used in the current thread.
+ */
+ def asReadOnlyCopy(): HashedRelation
+
+ /**
+ * Release any used resources.
+ */
+ def close(): Unit
}
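
After this change every HashedRelation exposes the same narrow surface: get returns an iterator of matches (or null), getValue serves the unique-key fast path, keyIsUnique tells callers which path is safe, asReadOnlyCopy gives each task its own probe cursor and row buffer, and close releases acquired memory. A toy, map-backed stand-in that mirrors the shape of the interface (illustrative only, no Spark types):

// Toy stand-in for the reshaped interface; not Spark's UnsafeRow-backed implementation.
trait ToyRelation[K, V] {
  def get(key: K): Iterator[V]          // null when the key is absent, as in the patch
  def getValue(key: K): V               // only meaningful when keyIsUnique
  def keyIsUnique: Boolean
  def asReadOnlyCopy(): ToyRelation[K, V]
  def close(): Unit
}

final class MapRelation[K, V](m: Map[K, Seq[V]]) extends ToyRelation[K, V] {
  def get(key: K): Iterator[V] = m.get(key).map(_.iterator).orNull
  def getValue(key: K): V = m.get(key).map(_.head).getOrElse(null.asInstanceOf[V])
  def keyIsUnique: Boolean = m.values.forall(_.size == 1)
  def asReadOnlyCopy(): ToyRelation[K, V] = new MapRelation(m)   // immutable map, sharing is safe
  def close(): Unit = ()                                         // nothing off-heap to release here
}
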
private[execution] object HashedRelation {
/**
* Create a HashedRelation from an Iterator of InternalRow.
- *
- * Note: The caller should make sure that these InternalRow are different objects.
*/
def apply(
- canJoinKeyFitWithinLong: Boolean,
input: Iterator[InternalRow],
- keyGenerator: Projection,
- sizeEstimate: Int = 64): HashedRelation = {
+ key: Seq[Expression],
+ sizeEstimate: Int = 64,
+ taskMemoryManager: TaskMemoryManager = null): HashedRelation = {
+ val mm = Option(taskMemoryManager).getOrElse {
+ new TaskMemoryManager(
+ new StaticMemoryManager(
+ new SparkConf().set("spark.memory.offHeap.enabled", "false"),
+ Long.MaxValue,
+ Long.MaxValue,
+ 1),
+ 0)
+ }
- if (canJoinKeyFitWithinLong) {
- LongHashedRelation(input, keyGenerator, sizeEstimate)
+ if (key.length == 1 && key.head.dataType == LongType) {
+ LongHashedRelation(input, key, sizeEstimate, mm)
} else {
- UnsafeHashedRelation(
- input, keyGenerator.asInstanceOf[UnsafeProjection], sizeEstimate)
+ UnsafeHashedRelation(input, key, sizeEstimate, mm)
}
}
}
/**
- * A HashedRelation for UnsafeRow, which is backed by HashMap or BytesToBytesMap that maps the key
- * into a sequence of values.
- *
- * When it's created, it uses HashMap. After it's serialized and deserialized, it switch to use
- * BytesToBytesMap for better memory performance (multiple values for the same are stored as a
- * continuous byte array.
+ * A HashedRelation for UnsafeRow, which is backed by a BytesToBytesMap.
*
* It's serialized in the following format:
* [number of keys]
- * [size of key] [size of all values in bytes] [key bytes] [bytes for all values]
- * ...
- *
- * All the values are serialized as following:
- * [number of fields] [number of bytes] [underlying bytes of UnsafeRow]
- * ...
+ * [size of key] [size of value] [key bytes] [bytes for value]
*/
-private[joins] final class UnsafeHashedRelation(
- private var hashTable: JavaHashMap[UnsafeRow, CompactBuffer[UnsafeRow]])
- extends HashedRelation
- with KnownSizeEstimation
- with Externalizable {
+private[joins] class UnsafeHashedRelation(
+ private var numFields: Int,
+ private var binaryMap: BytesToBytesMap)
+ extends HashedRelation with Externalizable {
- private[joins] def this() = this(null) // Needed for serialization
+ private[joins] def this() = this(0, null) // Needed for serialization
- // Use BytesToBytesMap in executor for better performance (it's created when deserialization)
- // This is used in broadcast joins and distributed mode only
- @transient private[this] var binaryMap: BytesToBytesMap = _
+ override def keyIsUnique: Boolean = binaryMap.numKeys() == binaryMap.numValues()
- /**
- * Return the size of the unsafe map on the executors.
- *
- * For broadcast joins, this hashed relation is bigger on the driver because it is
- * represented as a Java hash map there. While serializing the map to the executors,
- * however, we rehash the contents in a binary map to reduce the memory footprint on
- * the executors.
- *
- * For non-broadcast joins or in local mode, return 0.
- */
- override def getMemorySize: Long = {
- if (binaryMap != null) {
- binaryMap.getTotalMemoryConsumption
- } else {
- 0
- }
+ override def asReadOnlyCopy(): UnsafeHashedRelation = {
+ new UnsafeHashedRelation(numFields, binaryMap)
}
- override def estimatedSize: Long = {
- if (binaryMap != null) {
- binaryMap.getTotalMemoryConsumption
- } else {
- SizeEstimator.estimate(hashTable)
- }
- }
+ override def estimatedSize: Long = binaryMap.getTotalMemoryConsumption
- override def get(key: InternalRow): Seq[InternalRow] = {
- val unsafeKey = key.asInstanceOf[UnsafeRow]
+ // re-used in get()/getValue()
+ var resultRow = new UnsafeRow(numFields)
- if (binaryMap != null) {
- // Used in Broadcast join
- val map = binaryMap // avoid the compiler error
- val loc = new map.Location // this could be allocated in stack
- binaryMap.safeLookup(unsafeKey.getBaseObject, unsafeKey.getBaseOffset,
- unsafeKey.getSizeInBytes, loc, unsafeKey.hashCode())
- if (loc.isDefined) {
- val buffer = CompactBuffer[UnsafeRow]()
-
- val base = loc.getValueBase
- var offset = loc.getValueOffset
- val last = offset + loc.getValueLength
- while (offset < last) {
- val numFields = Platform.getInt(base, offset)
- val sizeInBytes = Platform.getInt(base, offset + 4)
- offset += 8
-
- val row = new UnsafeRow(numFields)
- row.pointTo(base, offset, sizeInBytes)
- buffer += row
- offset += sizeInBytes
+ override def get(key: InternalRow): Iterator[InternalRow] = {
+ val unsafeKey = key.asInstanceOf[UnsafeRow]
+ val map = binaryMap // avoid the compiler error
+ val loc = new map.Location // this could be allocated in stack
+ binaryMap.safeLookup(unsafeKey.getBaseObject, unsafeKey.getBaseOffset,
+ unsafeKey.getSizeInBytes, loc, unsafeKey.hashCode())
+ if (loc.isDefined) {
+ new Iterator[UnsafeRow] {
+ private var _hasNext = true
+ override def hasNext: Boolean = _hasNext
+ override def next(): UnsafeRow = {
+ resultRow.pointTo(loc.getValueBase, loc.getValueOffset, loc.getValueLength)
+ _hasNext = loc.nextValue()
+ resultRow
}
- buffer
- } else {
- null
}
+ } else {
+ null
+ }
+ }
+ def getValue(key: InternalRow): InternalRow = {
+ val unsafeKey = key.asInstanceOf[UnsafeRow]
+ val map = binaryMap // avoid the compiler error
+ val loc = new map.Location // this could be allocated in stack
+ binaryMap.safeLookup(unsafeKey.getBaseObject, unsafeKey.getBaseOffset,
+ unsafeKey.getSizeInBytes, loc, unsafeKey.hashCode())
+ if (loc.isDefined) {
+ resultRow.pointTo(loc.getValueBase, loc.getValueOffset, loc.getValueLength)
+ resultRow
} else {
- // Use the Java HashMap in local mode or for non-broadcast joins (e.g. ShuffleHashJoin)
- hashTable.get(unsafeKey)
+ null
}
}
- override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException {
- if (binaryMap != null) {
- // This could happen when a cached broadcast object need to be dumped into disk to free memory
- out.writeInt(binaryMap.numElements())
-
- var buffer = new Array[Byte](64)
- def write(base: Object, offset: Long, length: Int): Unit = {
- if (buffer.length < length) {
- buffer = new Array[Byte](length)
- }
- Platform.copyMemory(base, offset, buffer, Platform.BYTE_ARRAY_OFFSET, length)
- out.write(buffer, 0, length)
- }
+ override def close(): Unit = {
+ binaryMap.free()
+ }
- val iter = binaryMap.iterator()
- while (iter.hasNext) {
- val loc = iter.next()
- // [key size] [values size] [key bytes] [values bytes]
- out.writeInt(loc.getKeyLength)
- out.writeInt(loc.getValueLength)
- write(loc.getKeyBase, loc.getKeyOffset, loc.getKeyLength)
- write(loc.getValueBase, loc.getValueOffset, loc.getValueLength)
+ override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException {
+ out.writeInt(numFields)
+ // TODO: move these into BytesToBytesMap
+ out.writeInt(binaryMap.numKeys())
+ out.writeInt(binaryMap.numValues())
+
+ var buffer = new Array[Byte](64)
+ def write(base: Object, offset: Long, length: Int): Unit = {
+ if (buffer.length < length) {
+ buffer = new Array[Byte](length)
}
+ Platform.copyMemory(base, offset, buffer, Platform.BYTE_ARRAY_OFFSET, length)
+ out.write(buffer, 0, length)
+ }
- } else {
- assert(hashTable != null)
- out.writeInt(hashTable.size())
-
- val iter = hashTable.entrySet().iterator()
- while (iter.hasNext) {
- val entry = iter.next()
- val key = entry.getKey
- val values = entry.getValue
-
- // write all the values as single byte array
- var totalSize = 0L
- var i = 0
- while (i < values.length) {
- totalSize += values(i).getSizeInBytes + 4 + 4
- i += 1
- }
- assert(totalSize < Integer.MAX_VALUE, "values are too big")
-
- // [key size] [values size] [key bytes] [values bytes]
- out.writeInt(key.getSizeInBytes)
- out.writeInt(totalSize.toInt)
- out.write(key.getBytes)
- i = 0
- while (i < values.length) {
- // [num of fields] [num of bytes] [row bytes]
- // write the integer in native order, so they can be read by UNSAFE.getInt()
- if (ByteOrder.nativeOrder() == ByteOrder.BIG_ENDIAN) {
- out.writeInt(values(i).numFields())
- out.writeInt(values(i).getSizeInBytes)
- } else {
- out.writeInt(Integer.reverseBytes(values(i).numFields()))
- out.writeInt(Integer.reverseBytes(values(i).getSizeInBytes))
- }
- out.write(values(i).getBytes)
- i += 1
- }
- }
+ val iter = binaryMap.iterator()
+ while (iter.hasNext) {
+ val loc = iter.next()
+ // [key size] [values size] [key bytes] [value bytes]
+ out.writeInt(loc.getKeyLength)
+ out.writeInt(loc.getValueLength)
+ write(loc.getKeyBase, loc.getKeyOffset, loc.getKeyLength)
+ write(loc.getValueBase, loc.getValueOffset, loc.getValueLength)
}
}
override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException {
+ numFields = in.readInt()
+ resultRow = new UnsafeRow(numFields)
val nKeys = in.readInt()
+ val nValues = in.readInt()
// This is used in Broadcast, shared by multiple tasks, so we use on-heap memory
// TODO(josh): This needs to be revisited before we merge this patch; making this change now
// so that tests compile:
@@ -314,7 +226,7 @@ private[joins] final class UnsafeHashedRelation(
var i = 0
var keyBuffer = new Array[Byte](1024)
var valuesBuffer = new Array[Byte](1024)
- while (i < nKeys) {
+ while (i < nValues) {
val keySize = in.readInt()
val valuesSize = in.readInt()
if (keySize > keyBuffer.length) {
@@ -326,13 +238,11 @@ private[joins] final class UnsafeHashedRelation(
}
in.readFully(valuesBuffer, 0, valuesSize)
- // put it into binary map
val loc = binaryMap.lookup(keyBuffer, Platform.BYTE_ARRAY_OFFSET, keySize)
- assert(!loc.isDefined, "Duplicated key found!")
- val putSuceeded = loc.putNewKey(
- keyBuffer, Platform.BYTE_ARRAY_OFFSET, keySize,
+ val putSuceeded = loc.append(keyBuffer, Platform.BYTE_ARRAY_OFFSET, keySize,
valuesBuffer, Platform.BYTE_ARRAY_OFFSET, valuesSize)
if (!putSuceeded) {
+ binaryMap.free()
throw new IOException("Could not allocate memory to grow BytesToBytesMap")
}
i += 1
@@ -344,279 +254,503 @@ private[joins] object UnsafeHashedRelation {
def apply(
input: Iterator[InternalRow],
- keyGenerator: UnsafeProjection,
- sizeEstimate: Int): HashedRelation = {
+ key: Seq[Expression],
+ sizeEstimate: Int,
+ taskMemoryManager: TaskMemoryManager): HashedRelation = {
- // Use a Java hash table here because unsafe maps expect fixed size records
- // TODO: Use BytesToBytesMap for memory efficiency
- val hashTable = new JavaHashMap[UnsafeRow, CompactBuffer[UnsafeRow]](sizeEstimate)
+ val pageSizeBytes = Option(SparkEnv.get).map(_.memoryManager.pageSizeBytes)
+ .getOrElse(new SparkConf().getSizeAsBytes("spark.buffer.pageSize", "16m"))
+
+ val binaryMap = new BytesToBytesMap(
+ taskMemoryManager,
+ // Only 70% of the slots can be used before growing; more capacity helps to reduce collisions
+ (sizeEstimate * 1.5 + 1).toInt,
+ pageSizeBytes)
// Create a mapping of buildKeys -> rows
+ val keyGenerator = UnsafeProjection.create(key)
+ var numFields = 0
while (input.hasNext) {
- val unsafeRow = input.next().asInstanceOf[UnsafeRow]
- val rowKey = keyGenerator(unsafeRow)
- if (!rowKey.anyNull) {
- val existingMatchList = hashTable.get(rowKey)
- val matchList = if (existingMatchList == null) {
- val newMatchList = new CompactBuffer[UnsafeRow]()
- hashTable.put(rowKey.copy(), newMatchList)
- newMatchList
- } else {
- existingMatchList
+ val row = input.next().asInstanceOf[UnsafeRow]
+ numFields = row.numFields()
+ val key = keyGenerator(row)
+ if (!key.anyNull) {
+ val loc = binaryMap.lookup(key.getBaseObject, key.getBaseOffset, key.getSizeInBytes)
+ val success = loc.append(
+ key.getBaseObject, key.getBaseOffset, key.getSizeInBytes,
+ row.getBaseObject, row.getBaseOffset, row.getSizeInBytes)
+ if (!success) {
+ binaryMap.free()
+ throw new SparkException("There is no enough memory to build hash map")
}
- matchList += unsafeRow
}
}
- // TODO: create UniqueUnsafeRelation
- new UnsafeHashedRelation(hashTable)
+ new UnsafeHashedRelation(numFields, binaryMap)
}
}
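
UnsafeHashedRelation.apply no longer groups rows into CompactBuffers; it appends every (key, row) pair straight into the BytesToBytesMap, get walks the appended values via loc.nextValue(), and keyIsUnique falls out of comparing numKeys() with numValues(). A behavioural sketch with an ordinary grouped map (toy data only, not the byte-level layout):

object MultiAppendSketch {
  // Behavioural sketch of append + iterate; the real map stores raw UnsafeRow bytes.
  def main(args: Array[String]): Unit = {
    val rows = Seq(1 -> "a", 2 -> "b", 1 -> "c")             // duplicate key 1
    val grouped = rows.groupBy(_._1).map { case (k, kvs) => k -> kvs.map(_._2) }
    val keyIsUnique = grouped.size == rows.size              // numKeys == numValues
    println(keyIsUnique)                                     // false
    println(grouped.get(1).map(_.iterator.toList))           // Some(List(a, c))
    println(grouped.get(3).orNull)                           // null, like get() on a missing key
  }
}
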
/**
- * An interface for a hashed relation that the key is a Long.
- */
-private[joins] trait LongHashedRelation extends HashedRelation {
- override def get(key: InternalRow): Seq[InternalRow] = {
- get(key.getLong(0))
+ * An append-only hash map mapping from key of Long to UnsafeRow.
+ *
+ * The underlying bytes of all values (UnsafeRows) are packed together as a single byte array
+ * (`page`) in this format:
+ *
+ * [bytes of row1][address1][bytes of row2][address2] ...
+ *
+ * address1 (8 bytes) is the offset and size of the next value for the same key as row1; any key
+ * could have multiple values. The address at the end of the last value for every key is 0.
+ *
+ * The keys and addresses of their values could be stored in two modes:
+ *
+ * 1) sparse mode: the keys and addresses are stored in `array` as:
+ *
+ * [key1][address1][key2][address2]...[]
+ *
+ * address1 (Long) is the offset (in `page`) and size of the value for key1. The position of key1
+ * is determined by hashing key1 into the array; collisions are resolved by probing subsequent
+ * slot pairs.
+ *
+ * 2) dense mode: all the addresses are packed into a single array of long, as:
+ *
+ * [address1] [address2] ...
+ *
+ * address1 (Long) is the offset (in `page`) and size of the value for key1, the position is
+ * determined by `key1 - minKey`.
+ *
+ * The map is created in sparse mode, and key-value pairs can then be appended to it. Once
+ * appending is finished, the caller can call optimize() to try to turn the map into dense mode,
+ * which is faster to probe.
+ *
+ * see http://java-performance.info/implementing-world-fastest-java-int-to-int-hash-map/
+ */
+private[execution] final class LongToUnsafeRowMap(var mm: TaskMemoryManager, capacity: Int)
+ extends MemoryConsumer(mm) with Externalizable {
+
+ // Whether the keys are stored in dense mode or not.
+ private var isDense = false
+
+ // The minimum key
+ private var minKey = Long.MaxValue
+
+ // The maximum key
+ private var maxKey = Long.MinValue
+
+ // The array to store the key and offset of UnsafeRow in the page.
+ //
+ // Sparse mode: [key1] [offset1 | size1] [key2] [offset2 | size2] ...
+ // Dense mode: [offset1 | size1] [offset2 | size2]
+ private var array: Array[Long] = null
+ private var mask: Int = 0
+
+ // The page to store all bytes of UnsafeRow and the pointer to next rows.
+ // [row1][pointer1] [row2][pointer2]
+ private var page: Array[Byte] = null
+
+ // Current write cursor in the page.
+ private var cursor = Platform.BYTE_ARRAY_OFFSET
+
+ // The total number of values of all keys.
+ private var numValues = 0
+
+ // The number of unique keys.
+ private var numKeys = 0
+
+ // needed by serializer
+ def this() = {
+ this(
+ new TaskMemoryManager(
+ new StaticMemoryManager(
+ new SparkConf().set("spark.memory.offHeap.enabled", "false"),
+ Long.MaxValue,
+ Long.MaxValue,
+ 1),
+ 0),
+ 0)
}
-}
-private[joins] final class GeneralLongHashedRelation(
- private var hashTable: JavaHashMap[Long, CompactBuffer[UnsafeRow]])
- extends LongHashedRelation with Externalizable {
+ private def acquireMemory(size: Long): Unit = {
+ // do not support spilling
+ val got = mm.acquireExecutionMemory(size, MemoryMode.ON_HEAP, this)
+ if (got < size) {
+ freeMemory(got)
+ throw new SparkException(s"Can't acquire $size bytes memory to build hash relation, " +
+ s"got $got bytes")
+ }
+ }
- // Needed for serialization (it is public to make Java serialization work)
- def this() = this(null)
+ private def freeMemory(size: Long): Unit = {
+ mm.releaseExecutionMemory(size, MemoryMode.ON_HEAP, this)
+ }
+
+ private def init(): Unit = {
+ if (mm != null) {
+ var n = 1
+ while (n < capacity) n *= 2
+ acquireMemory(n * 2 * 8 + (1 << 20))
+ array = new Array[Long](n * 2)
+ mask = n * 2 - 2
+ page = new Array[Byte](1 << 20) // 1M bytes
+ }
+ }
- override def get(key: Long): Seq[InternalRow] = hashTable.get(key)
+ init()
- override def writeExternal(out: ObjectOutput): Unit = {
- writeBytes(out, SparkSqlSerializer.serialize(hashTable))
+ def spill(size: Long, trigger: MemoryConsumer): Long = 0L
+
+ /**
+ * Returns whether all the keys are unique.
+ */
+ def keyIsUnique: Boolean = numKeys == numValues
+
+ /**
+ * Returns total memory consumption.
+ */
+ def getTotalMemoryConsumption: Long = array.length * 8 + page.length
+
+ /**
+ * Returns the first slot of the array that stores the keys (sparse mode).
+ */
+ private def firstSlot(key: Long): Int = {
+ val h = key * 0x9E3779B9L
+ (h ^ (h >> 32)).toInt & mask
}
- override def readExternal(in: ObjectInput): Unit = {
- hashTable = SparkSqlSerializer.deserialize(readBytes(in))
+ /**
+ * Returns the next probe in the array.
+ */
+ private def nextSlot(pos: Int): Int = (pos + 2) & mask
+
+ private def getRow(address: Long, resultRow: UnsafeRow): UnsafeRow = {
+ val offset = address >>> 32
+ val size = address & 0xffffffffL
+ resultRow.pointTo(page, offset, size.toInt)
+ resultRow
}
-}
-private[joins] final class UniqueLongHashedRelation(
- private var hashTable: JavaHashMap[Long, UnsafeRow])
- extends UniqueHashedRelation with LongHashedRelation with Externalizable {
+ /**
+ * Returns the single UnsafeRow for given key, or null if not found.
+ */
+ def getValue(key: Long, resultRow: UnsafeRow): UnsafeRow = {
+ if (isDense) {
+ val idx = (key - minKey).toInt
+ if (idx >= 0 && key <= maxKey && array(idx) > 0) {
+ return getRow(array(idx), resultRow)
+ }
+ } else {
+ var pos = firstSlot(key)
+ while (array(pos + 1) != 0) {
+ if (array(pos) == key) {
+ return getRow(array(pos + 1), resultRow)
+ }
+ pos = nextSlot(pos)
+ }
+ }
+ null
+ }
- // Needed for serialization (it is public to make Java serialization work)
- def this() = this(null)
+ /**
+ * Returns an interator of UnsafeRow for multiple linked values.
+ */
+ private def valueIter(address: Long, resultRow: UnsafeRow): Iterator[UnsafeRow] = {
+ new Iterator[UnsafeRow] {
+ var addr = address
+ override def hasNext: Boolean = addr != 0
+ override def next(): UnsafeRow = {
+ val offset = addr >>> 32
+ val size = addr & 0xffffffffL
+ resultRow.pointTo(page, offset, size.toInt)
+ addr = Platform.getLong(page, offset + size)
+ resultRow
+ }
+ }
+ }
- override def getValue(key: InternalRow): InternalRow = {
- getValue(key.getLong(0))
+ /**
+ * Returns an iterator for all the values for the given key, or null if no value found.
+ */
+ def get(key: Long, resultRow: UnsafeRow): Iterator[UnsafeRow] = {
+ if (isDense) {
+ val idx = (key - minKey).toInt
+ if (idx >= 0 && key <= maxKey && array(idx) > 0) {
+ return valueIter(array(idx), resultRow)
+ }
+ } else {
+ var pos = firstSlot(key)
+ while (array(pos + 1) != 0) {
+ if (array(pos) == key) {
+ return valueIter(array(pos + 1), resultRow)
+ }
+ pos = nextSlot(pos)
+ }
+ }
+ null
}
- override def getValue(key: Long): InternalRow = {
- hashTable.get(key)
+ /**
+ * Appends the key and row into this map.
+ */
+ def append(key: Long, row: UnsafeRow): Unit = {
+ if (key < minKey) {
+ minKey = key
+ }
+ if (key > maxKey) {
+ maxKey = key
+ }
+
+ // There are 8 bytes for the pointer to the next value
+ if (cursor + 8 + row.getSizeInBytes > page.length + Platform.BYTE_ARRAY_OFFSET) {
+ val used = page.length
+ if (used * 2L > (1L << 31)) {
+ sys.error("Can't allocate a page that is larger than 2G")
+ }
+ acquireMemory(used * 2)
+ val newPage = new Array[Byte](used * 2)
+ System.arraycopy(page, 0, newPage, 0, cursor - Platform.BYTE_ARRAY_OFFSET)
+ page = newPage
+ freeMemory(used)
+ }
+
+ // copy the bytes of UnsafeRow
+ val offset = cursor
+ Platform.copyMemory(row.getBaseObject, row.getBaseOffset, page, cursor, row.getSizeInBytes)
+ cursor += row.getSizeInBytes
+ Platform.putLong(page, cursor, 0)
+ cursor += 8
+ numValues += 1
+ updateIndex(key, (offset.toLong << 32) | row.getSizeInBytes)
+ }
+
+ /**
+ * Update the address in array for given key.
+ */
+ private def updateIndex(key: Long, address: Long): Unit = {
+ var pos = firstSlot(key)
+ while (array(pos) != key && array(pos + 1) != 0) {
+ pos = nextSlot(pos)
+ }
+ if (array(pos + 1) == 0) {
+ // this is the first value for this key, put the address in array.
+ array(pos) = key
+ array(pos + 1) = address
+ numKeys += 1
+ if (numKeys * 4 > array.length) {
+ // reached half of the capacity
+ growArray()
+ }
+ } else {
+ // there are some values for this key, put the address in the front of them.
+ val pointer = (address >>> 32) + (address & 0xffffffffL)
+ Platform.putLong(page, pointer, array(pos + 1))
+ array(pos + 1) = address
+ }
+ }
+
+ private def growArray(): Unit = {
+ var old_array = array
+ val n = array.length
+ numKeys = 0
+ acquireMemory(n * 2 * 8)
+ array = new Array[Long](n * 2)
+ mask = n * 2 - 2
+ var i = 0
+ while (i < old_array.length) {
+ if (old_array(i + 1) > 0) {
+ updateIndex(old_array(i), old_array(i + 1))
+ }
+ i += 2
+ }
+ old_array = null // release the reference to old array
+ freeMemory(n * 8)
+ }
+
+ /**
+ * Try to turn the map into dense mode, which is faster to probe.
+ */
+ def optimize(): Unit = {
+ val range = maxKey - minKey
+ // Convert to dense mode if it does not require more memory or could fit within L1 cache
+ if (range < array.length || range < 1024) {
+ try {
+ acquireMemory((range + 1) * 8)
+ } catch {
+ case e: SparkException =>
+ // there is not enough memory to convert
+ return
+ }
+ val denseArray = new Array[Long]((range + 1).toInt)
+ var i = 0
+ while (i < array.length) {
+ if (array(i + 1) > 0) {
+ val idx = (array(i) - minKey).toInt
+ denseArray(idx) = array(i + 1)
+ }
+ i += 2
+ }
+ val old_length = array.length
+ array = denseArray
+ isDense = true
+ freeMemory(old_length * 8)
+ }
+ }
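
optimize() converts to dense mode only when maxKey - minKey is smaller than the sparse array (or under 1024), i.e. when an array indexed directly by key - minKey costs no more memory than the hash layout, or is small enough to sit in cache. A tiny REPL-style restatement of that condition with made-up numbers:

// Restates the dense-mode test from optimize(); numbers are illustrative only.
def wouldGoDense(minKey: Long, maxKey: Long, sparseArrayLength: Int): Boolean = {
  val range = maxKey - minKey
  range < sparseArrayLength || range < 1024
}

wouldGoDense(100L, 999L, 2048)      // true: a ~900-slot dense array beats 2048 sparse slots
wouldGoDense(0L, 5000000L, 2048)    // false: the key range is far larger than the sparse array
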
+
+ /**
+ * Free all the memory acquired by this map.
+ */
+ def free(): Unit = {
+ if (page != null) {
+ freeMemory(page.length)
+ page = null
+ }
+ if (array != null) {
+ freeMemory(array.length * 8)
+ array = null
+ }
}
override def writeExternal(out: ObjectOutput): Unit = {
- writeBytes(out, SparkSqlSerializer.serialize(hashTable))
+ out.writeBoolean(isDense)
+ out.writeLong(minKey)
+ out.writeLong(maxKey)
+ out.writeInt(numKeys)
+ out.writeInt(numValues)
+
+ out.writeInt(array.length)
+ val buffer = new Array[Byte](4 << 10)
+ var offset = Platform.LONG_ARRAY_OFFSET
+ val end = array.length * 8 + Platform.LONG_ARRAY_OFFSET
+ while (offset < end) {
+ val size = Math.min(buffer.length, end - offset)
+ Platform.copyMemory(array, offset, buffer, Platform.BYTE_ARRAY_OFFSET, size)
+ out.write(buffer, 0, size)
+ offset += size
+ }
+
+ val used = cursor - Platform.BYTE_ARRAY_OFFSET
+ out.writeInt(used)
+ out.write(page, 0, used)
}
override def readExternal(in: ObjectInput): Unit = {
- hashTable = SparkSqlSerializer.deserialize(readBytes(in))
+ isDense = in.readBoolean()
+ minKey = in.readLong()
+ maxKey = in.readLong()
+ numKeys = in.readInt()
+ numValues = in.readInt()
+
+ val length = in.readInt()
+ array = new Array[Long](length)
+ mask = length - 2
+ val buffer = new Array[Byte](4 << 10)
+ var offset = Platform.LONG_ARRAY_OFFSET
+ val end = length * 8 + Platform.LONG_ARRAY_OFFSET
+ while (offset < end) {
+ val size = Math.min(buffer.length, end - offset)
+ in.readFully(buffer, 0, size)
+ Platform.copyMemory(buffer, Platform.BYTE_ARRAY_OFFSET, array, offset, size)
+ offset += size
+ }
+
+ val numBytes = in.readInt()
+ page = new Array[Byte](numBytes)
+ in.readFully(page)
}
}
-/**
- * A relation that pack all the rows into a byte array, together with offsets and sizes.
- *
- * All the bytes of UnsafeRow are packed together as `bytes`:
- *
- * [ Row0 ][ Row1 ][] ... [ RowN ]
- *
- * With keys:
- *
- * start start+1 ... start+N
- *
- * `offsets` are offsets of UnsafeRows in the `bytes`
- * `sizes` are the numbers of bytes of UnsafeRows, 0 means no row for this key.
- *
- * For example, two UnsafeRows (24 bytes and 32 bytes), with keys as 3 and 5 will stored as:
- *
- * start = 3
- * offsets = [0, 0, 24]
- * sizes = [24, 0, 32]
- * bytes = [0 - 24][][24 - 56]
- */
-private[joins] final class LongArrayRelation(
- private var numFields: Int,
- private var start: Long,
- private var offsets: Array[Int],
- private var sizes: Array[Int],
- private var bytes: Array[Byte]
- ) extends UniqueHashedRelation with LongHashedRelation with Externalizable {
+private[joins] class LongHashedRelation(
+ private var nFields: Int,
+ private var map: LongToUnsafeRowMap) extends HashedRelation with Externalizable {
+
+ private var resultRow: UnsafeRow = new UnsafeRow(nFields)
// Needed for serialization (it is public to make Java serialization work)
- def this() = this(0, 0L, null, null, null)
+ def this() = this(0, null)
- override def getValue(key: InternalRow): InternalRow = {
- getValue(key.getLong(0))
- }
+ override def asReadOnlyCopy(): LongHashedRelation = new LongHashedRelation(nFields, map)
- override def getMemorySize: Long = {
- offsets.length * 4 + sizes.length * 4 + bytes.length
- }
+ override def estimatedSize: Long = map.getTotalMemoryConsumption
- override def getValue(key: Long): InternalRow = {
- val idx = (key - start).toInt
- if (idx >= 0 && idx < sizes.length && sizes(idx) > 0) {
- val result = new UnsafeRow(numFields)
- result.pointTo(bytes, Platform.BYTE_ARRAY_OFFSET + offsets(idx), sizes(idx))
- result
+ override def get(key: InternalRow): Iterator[InternalRow] = {
+ if (key.isNullAt(0)) {
+ null
} else {
+ get(key.getLong(0))
+ }
+ }
+
+ override def getValue(key: InternalRow): InternalRow = {
+ if (key.isNullAt(0)) {
null
+ } else {
+ getValue(key.getLong(0))
}
}
+ override def get(key: Long): Iterator[InternalRow] = map.get(key, resultRow)
+
+ override def getValue(key: Long): InternalRow = map.getValue(key, resultRow)
+
+ override def keyIsUnique: Boolean = map.keyIsUnique
+
+ override def close(): Unit = {
+ map.free()
+ }
+
override def writeExternal(out: ObjectOutput): Unit = {
- out.writeInt(numFields)
- out.writeLong(start)
- out.writeInt(sizes.length)
- var i = 0
- while (i < sizes.length) {
- out.writeInt(sizes(i))
- i += 1
- }
- out.writeInt(bytes.length)
- out.write(bytes)
+ out.writeInt(nFields)
+ out.writeObject(map)
}
override def readExternal(in: ObjectInput): Unit = {
- numFields = in.readInt()
- start = in.readLong()
- val length = in.readInt()
- // read sizes of rows
- sizes = new Array[Int](length)
- offsets = new Array[Int](length)
- var i = 0
- var offset = 0
- while (i < length) {
- offsets(i) = offset
- sizes(i) = in.readInt()
- offset += sizes(i)
- i += 1
- }
- // read all the bytes
- val total = in.readInt()
- assert(total == offset)
- bytes = new Array[Byte](total)
- in.readFully(bytes)
+ nFields = in.readInt()
+ resultRow = new UnsafeRow(nFields)
+ map = in.readObject().asInstanceOf[LongToUnsafeRowMap]
}
}
/**
- * Create hashed relation with key that is long.
- */
+ * Create a hashed relation with a key that is a long.
+ */
private[joins] object LongHashedRelation {
-
- val DENSE_FACTOR = 0.2
-
def apply(
- input: Iterator[InternalRow],
- keyGenerator: Projection,
- sizeEstimate: Int): HashedRelation = {
+ input: Iterator[InternalRow],
+ key: Seq[Expression],
+ sizeEstimate: Int,
+ taskMemoryManager: TaskMemoryManager): LongHashedRelation = {
- // Use a Java hash table here because unsafe maps expect fixed size records
- val hashTable = new JavaHashMap[Long, CompactBuffer[UnsafeRow]](sizeEstimate)
+ val map: LongToUnsafeRowMap = new LongToUnsafeRowMap(taskMemoryManager, sizeEstimate)
+ val keyGenerator = UnsafeProjection.create(key)
// Create a mapping of key -> rows
var numFields = 0
- var keyIsUnique = true
- var minKey = Long.MaxValue
- var maxKey = Long.MinValue
while (input.hasNext) {
val unsafeRow = input.next().asInstanceOf[UnsafeRow]
numFields = unsafeRow.numFields()
val rowKey = keyGenerator(unsafeRow)
- if (!rowKey.anyNull) {
+ if (!rowKey.isNullAt(0)) {
val key = rowKey.getLong(0)
- minKey = math.min(minKey, key)
- maxKey = math.max(maxKey, key)
- val existingMatchList = hashTable.get(key)
- val matchList = if (existingMatchList == null) {
- val newMatchList = new CompactBuffer[UnsafeRow]()
- hashTable.put(key, newMatchList)
- newMatchList
- } else {
- keyIsUnique = false
- existingMatchList
- }
- matchList += unsafeRow
+ map.append(key, unsafeRow)
}
}
-
- if (keyIsUnique) {
- if (hashTable.size() > (maxKey - minKey) * DENSE_FACTOR) {
- // The keys are dense enough, so use LongArrayRelation
- val length = (maxKey - minKey).toInt + 1
- val sizes = new Array[Int](length)
- val offsets = new Array[Int](length)
- var offset = 0
- var i = 0
- while (i < length) {
- val rows = hashTable.get(i + minKey)
- if (rows != null) {
- offsets(i) = offset
- sizes(i) = rows(0).getSizeInBytes
- offset += sizes(i)
- }
- i += 1
- }
- val bytes = new Array[Byte](offset)
- i = 0
- while (i < length) {
- val rows = hashTable.get(i + minKey)
- if (rows != null) {
- rows(0).writeToMemory(bytes, Platform.BYTE_ARRAY_OFFSET + offsets(i))
- }
- i += 1
- }
- new LongArrayRelation(numFields, minKey, offsets, sizes, bytes)
-
- } else {
- // all the keys are unique, one row per key.
- val uniqHashTable = new JavaHashMap[Long, UnsafeRow](hashTable.size)
- val iter = hashTable.entrySet().iterator()
- while (iter.hasNext) {
- val entry = iter.next()
- uniqHashTable.put(entry.getKey, entry.getValue()(0))
- }
- new UniqueLongHashedRelation(uniqHashTable)
- }
- } else {
- new GeneralLongHashedRelation(hashTable)
- }
+ map.optimize()
+ new LongHashedRelation(numFields, map)
}
}
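LongToUnsafeRowMap itself is internal to Spark; as a rough analogue of the append/lookup flow used above, here is a self-contained sketch with a plain Long-keyed hash map and String "rows" (names and types are illustrative only):

    import scala.collection.mutable

    object LongKeyedAppendSketch extends App {
      // Rows appended under the same long key accumulate in a buffer, as in map.append(key, row).
      val map = mutable.HashMap.empty[Long, mutable.ArrayBuffer[String]]

      def append(key: Long, row: String): Unit =
        map.getOrElseUpdate(key, mutable.ArrayBuffer.empty[String]) += row

      append(1L, "row-a")
      append(1L, "row-b")  // second row under key 1: the key is no longer unique
      append(7L, "row-c")

      // Analogue of keyIsUnique: every key maps to exactly one row.
      val keyIsUnique = map.values.forall(_.size == 1)
      println(s"keyIsUnique=$keyIsUnique, rowsForKey1=${map(1L).mkString(", ")}")
    }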
/** The HashedRelationBroadcastMode requires that rows are broadcasted as a HashedRelation. */
-private[execution] case class HashedRelationBroadcastMode(
- canJoinKeyFitWithinLong: Boolean,
- keys: Seq[Expression],
- attributes: Seq[Attribute]) extends BroadcastMode {
+private[execution] case class HashedRelationBroadcastMode(key: Seq[Expression])
+ extends BroadcastMode {
override def transform(rows: Array[InternalRow]): HashedRelation = {
- val generator = UnsafeProjection.create(keys, attributes)
- HashedRelation(canJoinKeyFitWithinLong, rows.iterator, generator, rows.length)
+ HashedRelation(rows.iterator, canonicalizedKey, rows.length)
}
- private lazy val canonicalizedKeys: Seq[Expression] = {
- keys.map { e =>
- BindReferences.bindReference(e.canonicalized, attributes)
- }
+ private lazy val canonicalizedKey: Seq[Expression] = {
+ key.map { e => e.canonicalized }
}
override def compatibleWith(other: BroadcastMode): Boolean = other match {
- case m: HashedRelationBroadcastMode =>
- canJoinKeyFitWithinLong == m.canJoinKeyFitWithinLong &&
- canonicalizedKeys == m.canonicalizedKeys
+ case m: HashedRelationBroadcastMode => canonicalizedKey == m.canonicalizedKey
case _ => false
}
}
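compatibleWith now compares only the canonicalized key expressions, so two modes built from cosmetically different but equivalent keys are treated as exchangeable. The same idea with a throwaway expression type (purely illustrative, not Catalyst):

    object CanonicalKeySketch extends App {
      // Toy "expression" whose canonical form drops a cosmetic difference (the name's case).
      final case class ToyExpr(name: String) {
        def canonicalized: ToyExpr = ToyExpr(name.toLowerCase)
      }

      final case class ToyBroadcastMode(key: Seq[ToyExpr]) {
        private lazy val canonicalizedKey: Seq[ToyExpr] = key.map(_.canonicalized)
        def compatibleWith(other: ToyBroadcastMode): Boolean =
          canonicalizedKey == other.canonicalizedKey
      }

      // Differently written but equivalent keys compare as compatible.
      val a = ToyBroadcastMode(Seq(ToyExpr("ID")))
      val b = ToyBroadcastMode(Seq(ToyExpr("id")))
      assert(a.compatibleWith(b))
      println("compatible: " + a.compatibleWith(b))
    }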
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala
index 5c4f1ef60f..0c3e3c3fc1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala
@@ -17,11 +17,10 @@
package org.apache.spark.sql.execution.joins
-import org.apache.spark.{SparkException, TaskContext}
-import org.apache.spark.memory.MemoryMode
+import org.apache.spark.TaskContext
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{Expression, JoinedRow, UnsafeRow}
+import org.apache.spark.sql.catalyst.expressions.{BindReferences, Expression, UnsafeRow}
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.execution.{BinaryNode, SparkPlan}
@@ -45,6 +44,7 @@ case class ShuffledHashJoin(
override def outputPartitioning: Partitioning = joinType match {
case Inner => PartitioningCollection(Seq(left.outputPartitioning, right.outputPartitioning))
+ case LeftAnti => left.outputPartitioning
case LeftSemi => left.outputPartitioning
case LeftOuter => left.outputPartitioning
case RightOuter => right.outputPartitioning
@@ -56,75 +56,21 @@ case class ShuffledHashJoin(
override def requiredChildDistribution: Seq[Distribution] =
ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil
- private def buildHashedRelation(iter: Iterator[UnsafeRow]): HashedRelation = {
- // try to acquire some memory for the hash table, it could trigger other operator to free some
- // memory. The memory acquired here will mostly be used until the end of task.
+ private def buildHashedRelation(iter: Iterator[InternalRow]): HashedRelation = {
val context = TaskContext.get()
- val memoryManager = context.taskMemoryManager()
- var acquired = 0L
- var used = 0L
+ val relation = HashedRelation(iter, buildKeys, taskMemoryManager = context.taskMemoryManager())
+ // This relation is usually used until the end of the task.
context.addTaskCompletionListener((t: TaskContext) =>
- memoryManager.releaseExecutionMemory(acquired, MemoryMode.ON_HEAP, null)
+ relation.close()
)
-
- val copiedIter = iter.map { row =>
- // It's hard to guess what's exactly memory will be used, we have a rough guess here.
- // TODO: use BytesToBytesMap instead of HashMap for memory efficiency
- // Each pair in HashMap will have two UnsafeRows, one CompactBuffer, maybe 10+ pointers
- val needed = 150 + row.getSizeInBytes
- if (needed > acquired - used) {
- val got = memoryManager.acquireExecutionMemory(
- Math.max(memoryManager.pageSizeBytes(), needed), MemoryMode.ON_HEAP, null)
- if (got < needed) {
- throw new SparkException("Can't acquire enough memory to build hash map in shuffled" +
- "hash join, please use sort merge join by setting " +
- "spark.sql.join.preferSortMergeJoin=true")
- }
- acquired += got
- }
- used += needed
- // HashedRelation requires that the UnsafeRow should be separate objects.
- row.copy()
- }
-
- HashedRelation(canJoinKeyFitWithinLong, copiedIter, buildSideKeyGenerator)
+ relation
}
protected override def doExecute(): RDD[InternalRow] = {
val numOutputRows = longMetric("numOutputRows")
-
streamedPlan.execute().zipPartitions(buildPlan.execute()) { (streamIter, buildIter) =>
- val hashed = buildHashedRelation(buildIter.asInstanceOf[Iterator[UnsafeRow]])
- val joinedRow = new JoinedRow
- joinType match {
- case Inner =>
- hashJoin(streamIter, hashed, numOutputRows)
-
- case LeftSemi =>
- hashSemiJoin(streamIter, hashed, numOutputRows)
-
- case LeftOuter =>
- val keyGenerator = streamSideKeyGenerator
- val resultProj = createResultProjection
- streamIter.flatMap(currentRow => {
- val rowKey = keyGenerator(currentRow)
- joinedRow.withLeft(currentRow)
- leftOuterIterator(rowKey, joinedRow, hashed.get(rowKey), resultProj, numOutputRows)
- })
-
- case RightOuter =>
- val keyGenerator = streamSideKeyGenerator
- val resultProj = createResultProjection
- streamIter.flatMap(currentRow => {
- val rowKey = keyGenerator(currentRow)
- joinedRow.withRight(currentRow)
- rightOuterIterator(rowKey, hashed.get(rowKey), joinedRow, resultProj, numOutputRows)
- })
-
- case x =>
- throw new IllegalArgumentException(
- s"ShuffledHashJoin should not take $x as the JoinType")
- }
+ val hashed = buildHashedRelation(buildIter)
+ join(streamIter, hashed, numOutputRows)
}
}
}
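The rewritten doExecute builds a hashed relation from the build-side iterator once per partition and probes it with the streamed side. The build/probe shape in isolation, sketched with plain collections (Int keys, String rows; an illustrative analogue, not Spark's HashedRelation):

    object BuildProbeJoinSketch extends App {
      // Build phase: hash the build side by join key.
      def buildSide(rows: Iterator[(Int, String)]): Map[Int, Seq[String]] =
        rows.toSeq.groupBy(_._1).map { case (k, vs) => k -> vs.map(_._2) }

      // Probe phase: stream the other side and emit matched pairs (an inner join).
      def innerJoin(stream: Iterator[(Int, String)],
                    hashed: Map[Int, Seq[String]]): Iterator[(String, String)] =
        stream.flatMap { case (k, left) => hashed.getOrElse(k, Nil).map(right => (left, right)) }

      val hashed = buildSide(Iterator(1 -> "b1", 2 -> "b2", 2 -> "b3"))
      innerJoin(Iterator(2 -> "s1", 3 -> "s2"), hashed).foreach(println)  // (s1,b2) and (s1,b3)
    }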
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala
index 60bd8ea39a..0e7b2f2f31 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala
@@ -256,9 +256,9 @@ case class SortMergeJoin(
}
/**
- * Generate a function to scan both left and right to find a match, returns the term for
- * matched one row from left side and buffered rows from right side.
- */
+ * Generate a function to scan both left and right to find a match, returns the term for
+ * matched one row from left side and buffered rows from right side.
+ */
private def genScanner(ctx: CodegenContext): (String, String) = {
// Create class member for next row from both sides.
val leftRow = ctx.freshName("leftRow")
@@ -341,12 +341,12 @@ case class SortMergeJoin(
}
/**
- * Creates variables for left part of result row.
- *
- * In order to defer the access after condition and also only access once in the loop,
- * the variables should be declared separately from accessing the columns, we can't use the
- * codegen of BoundReference here.
- */
+ * Creates variables for left part of result row.
+ *
+ * In order to defer the access after condition and also only access once in the loop,
+ * the variables should be declared separately from accessing the columns, we can't use the
+ * codegen of BoundReference here.
+ */
private def createLeftVars(ctx: CodegenContext, leftRow: String): Seq[ExprCode] = {
ctx.INPUT_ROW = leftRow
left.output.zipWithIndex.map { case (a, i) =>
@@ -370,9 +370,9 @@ case class SortMergeJoin(
}
/**
- * Creates the variables for right part of result row, using BoundReference, since the right
- * part are accessed inside the loop.
- */
+ * Creates the variables for right part of result row, using BoundReference, since the right
+ * part are accessed inside the loop.
+ */
private def createRightVar(ctx: CodegenContext, rightRow: String): Seq[ExprCode] = {
ctx.INPUT_ROW = rightRow
right.output.zipWithIndex.map { case (a, i) =>
@@ -381,12 +381,12 @@ case class SortMergeJoin(
}
/**
- * Splits variables based on whether it's used by condition or not, returns the code to create
- * these variables before the condition and after the condition.
- *
- * Only a few columns are used by condition, then we can skip the accessing of those columns
- * that are not used by condition also filtered out by condition.
- */
+ * Splits variables based on whether it's used by condition or not, returns the code to create
+ * these variables before the condition and after the condition.
+ *
+ * Only a few columns are used by condition, then we can skip the accessing of those columns
+ * that are not used by condition also filtered out by condition.
+ */
private def splitVarsByCondition(
attributes: Seq[Attribute],
variables: Seq[ExprCode]): (String, String) = {
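As the comments above describe, generated variables are split by whether the join condition references them, so columns the condition filters out are never materialized. The partitioning itself is simple; a sketch with attribute names standing in for bound references:

    object SplitByConditionSketch extends App {
      // Names standing in for the output attributes; condRefs is the set the condition touches.
      val attributes = Seq("id", "name", "payload")
      val condRefs = Set("id")

      // Variables used by the condition are declared before it; the rest are deferred until after.
      val (beforeCondition, afterCondition) = attributes.partition(condRefs.contains)
      println(s"before condition: $beforeCondition, after condition: $afterCondition")
    }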
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
index 582dda8603..d2ab18ef0e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
@@ -17,14 +17,84 @@
package org.apache.spark.sql.execution
+import scala.language.existentials
+
+import org.apache.spark.api.java.function.MapFunction
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateSafeProjection, GenerateUnsafeProjection, GenerateUnsafeRowJoiner}
+import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.types.ObjectType
/**
+ * Takes the input row from child and turns it into an object using the given deserializer expression.
+ * The output of this operator is a single-field safe row containing the deserialized object.
+ */
+case class DeserializeToObject(
+ deserializer: Alias,
+ child: SparkPlan) extends UnaryNode with CodegenSupport {
+ override def output: Seq[Attribute] = deserializer.toAttribute :: Nil
+
+ override def upstreams(): Seq[RDD[InternalRow]] = {
+ child.asInstanceOf[CodegenSupport].upstreams()
+ }
+
+ protected override def doProduce(ctx: CodegenContext): String = {
+ child.asInstanceOf[CodegenSupport].produce(ctx, this)
+ }
+
+ override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
+ val bound = ExpressionCanonicalizer.execute(
+ BindReferences.bindReference(deserializer, child.output))
+ ctx.currentVars = input
+ val resultVars = bound.gen(ctx) :: Nil
+ consume(ctx, resultVars)
+ }
+
+ override protected def doExecute(): RDD[InternalRow] = {
+ child.execute().mapPartitionsInternal { iter =>
+ val projection = GenerateSafeProjection.generate(deserializer :: Nil, child.output)
+ iter.map(projection)
+ }
+ }
+}
+
+/**
+ * Takes the input object from child and turns it into an unsafe row using the given serializer
+ * expression. The output of its child must be a single-field row containing the input object.
+ */
+case class SerializeFromObject(
+ serializer: Seq[NamedExpression],
+ child: SparkPlan) extends UnaryNode with CodegenSupport {
+ override def output: Seq[Attribute] = serializer.map(_.toAttribute)
+
+ override def upstreams(): Seq[RDD[InternalRow]] = {
+ child.asInstanceOf[CodegenSupport].upstreams()
+ }
+
+ protected override def doProduce(ctx: CodegenContext): String = {
+ child.asInstanceOf[CodegenSupport].produce(ctx, this)
+ }
+
+ override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
+ val bound = serializer.map { expr =>
+ ExpressionCanonicalizer.execute(BindReferences.bindReference(expr, child.output))
+ }
+ ctx.currentVars = input
+ val resultVars = bound.map(_.gen(ctx))
+ consume(ctx, resultVars)
+ }
+
+ override protected def doExecute(): RDD[InternalRow] = {
+ child.execute().mapPartitionsInternal { iter =>
+ val projection = UnsafeProjection.create(serializer)
+ iter.map(projection)
+ }
+ }
+}
+
+/**
* Helper functions for physical operators that work with user defined objects.
*/
trait ObjectOperator extends SparkPlan {
@@ -68,6 +138,70 @@ case class MapPartitions(
}
/**
+ * Applies the given function to each input row and encodes the result.
+ *
+ * Note that each serializer expression takes the result object returned by the given function
+ * as input. This operator uses some tricks to make sure we only calculate the result
+ * object once. We don't use [[Project]] directly because subexpression elimination doesn't work
+ * with whole-stage codegen, and it's confusing to show the un-common-subexpression-eliminated
+ * version of a project in explain output.
+ */
+case class MapElements(
+ func: AnyRef,
+ deserializer: Expression,
+ serializer: Seq[NamedExpression],
+ child: SparkPlan) extends UnaryNode with ObjectOperator with CodegenSupport {
+ override def output: Seq[Attribute] = serializer.map(_.toAttribute)
+
+ override def upstreams(): Seq[RDD[InternalRow]] = {
+ child.asInstanceOf[CodegenSupport].upstreams()
+ }
+
+ protected override def doProduce(ctx: CodegenContext): String = {
+ child.asInstanceOf[CodegenSupport].produce(ctx, this)
+ }
+
+ override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
+ val (funcClass, methodName) = func match {
+ case m: MapFunction[_, _] => classOf[MapFunction[_, _]] -> "call"
+ case _ => classOf[Any => Any] -> "apply"
+ }
+ val funcObj = Literal.create(func, ObjectType(funcClass))
+ val resultObjType = serializer.head.collect { case b: BoundReference => b }.head.dataType
+ val callFunc = Invoke(funcObj, methodName, resultObjType, Seq(deserializer))
+
+ val bound = ExpressionCanonicalizer.execute(
+ BindReferences.bindReference(callFunc, child.output))
+ ctx.currentVars = input
+ val evaluated = bound.gen(ctx)
+
+ val resultObj = LambdaVariable(evaluated.value, evaluated.isNull, resultObjType)
+ val outputFields = serializer.map(_ transform {
+ case _: BoundReference => resultObj
+ })
+ val resultVars = outputFields.map(_.gen(ctx))
+ s"""
+ ${evaluated.code}
+ ${consume(ctx, resultVars)}
+ """
+ }
+
+ override protected def doExecute(): RDD[InternalRow] = {
+ val callFunc: Any => Any = func match {
+ case m: MapFunction[_, _] => i => m.asInstanceOf[MapFunction[Any, Any]].call(i)
+ case _ => func.asInstanceOf[Any => Any]
+ }
+ child.execute().mapPartitionsInternal { iter =>
+ val getObject = generateToObject(deserializer, child.output)
+ val outputObject = generateToRow(serializer)
+ iter.map(row => outputObject(callFunc(getObject(row))))
+ }
+ }
+
+ override def outputOrdering: Seq[SortOrder] = child.outputOrdering
+}
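doExecute above normalizes either a Java-style MapFunction or a Scala closure to a single Any => Any before applying it per row. The same dispatch, sketched with a stand-in single-method interface (MyMapFunction is illustrative, not Spark's Java API):

    object FuncNormalizationSketch extends App {
      // Stand-in for a Java-style single-method function interface.
      trait MyMapFunction[T, U] { def call(value: T): U }

      // Normalize either flavour of user function to Any => Any.
      def normalize(func: AnyRef): Any => Any = func match {
        case m: MyMapFunction[_, _] => (i: Any) => m.asInstanceOf[MyMapFunction[Any, Any]].call(i)
        case _ => func.asInstanceOf[Any => Any]
      }

      val javaStyle: AnyRef = new MyMapFunction[Int, Int] { def call(value: Int): Int = value + 1 }
      val scalaStyle: AnyRef = (x: Int) => x * 2

      println(normalize(javaStyle)(41))   // 42
      println(normalize(scalaStyle)(21))  // 42
    }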
+
+/**
* Applies the given function to each input row, appending the encoded result at the end of the row.
*/
case class AppendColumns(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchPythonEvaluation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchPythonEvaluation.scala
index 79e4491026..c9ab40a0a9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchPythonEvaluation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchPythonEvaluation.scala
@@ -18,16 +18,17 @@
package org.apache.spark.sql.execution.python
import scala.collection.JavaConverters._
+import scala.collection.mutable.ArrayBuffer
import net.razorvine.pickle.{Pickler, Unpickler}
import org.apache.spark.TaskContext
-import org.apache.spark.api.python.PythonRunner
+import org.apache.spark.api.python.{ChainedPythonFunctions, PythonRunner}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{Attribute, GenericMutableRow, JoinedRow, UnsafeProjection}
+import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.execution.SparkPlan
-import org.apache.spark.sql.types.{StructField, StructType}
+import org.apache.spark.sql.types.{DataType, StructField, StructType}
/**
@@ -40,11 +41,23 @@ import org.apache.spark.sql.types.{StructField, StructType}
* we drain the queue to find the original input row. Note that if the Python process is way too
* slow, this could lead to the queue growing unbounded and eventually run out of memory.
*/
-case class BatchPythonEvaluation(udf: PythonUDF, output: Seq[Attribute], child: SparkPlan)
+case class BatchPythonEvaluation(udfs: Seq[PythonUDF], output: Seq[Attribute], child: SparkPlan)
extends SparkPlan {
def children: Seq[SparkPlan] = child :: Nil
+ private def collectFunctions(udf: PythonUDF): (ChainedPythonFunctions, Seq[Expression]) = {
+ udf.children match {
+ case Seq(u: PythonUDF) =>
+ val (chained, children) = collectFunctions(u)
+ (ChainedPythonFunctions(chained.funcs ++ Seq(udf.func)), children)
+ case children =>
+ // There should not be any other UDFs, or the children can't be evaluated directly.
+ assert(children.forall(_.find(_.isInstanceOf[PythonUDF]).isEmpty))
+ (ChainedPythonFunctions(Seq(udf.func)), udf.children)
+ }
+ }
+
protected override def doExecute(): RDD[InternalRow] = {
val inputRDD = child.execute().map(_.copy())
val bufferSize = inputRDD.conf.getInt("spark.buffer.size", 65536)
@@ -57,17 +70,47 @@ case class BatchPythonEvaluation(udf: PythonUDF, output: Seq[Attribute], child:
// combine input with output from Python.
val queue = new java.util.concurrent.ConcurrentLinkedQueue[InternalRow]()
- val pickle = new Pickler
- val currentRow = newMutableProjection(udf.children, child.output)()
- val fields = udf.children.map(_.dataType)
- val schema = new StructType(fields.map(t => new StructField("", t, true)).toArray)
+ val (pyFuncs, inputs) = udfs.map(collectFunctions).unzip
+
+ // flatten all the arguments
+ val allInputs = new ArrayBuffer[Expression]
+ val dataTypes = new ArrayBuffer[DataType]
+ val argOffsets = inputs.map { input =>
+ input.map { e =>
+ if (allInputs.exists(_.semanticEquals(e))) {
+ allInputs.indexWhere(_.semanticEquals(e))
+ } else {
+ allInputs += e
+ dataTypes += e.dataType
+ allInputs.length - 1
+ }
+ }.toArray
+ }.toArray
+ val projection = newMutableProjection(allInputs, child.output)()
+ val schema = StructType(dataTypes.map(dt => StructField("", dt)))
+ val needConversion = dataTypes.exists(EvaluatePython.needConversionInPython)
+ // enable memo iff we serialize the row with schema (schema and class should be memoized)
+ val pickle = new Pickler(needConversion)
// Input iterator to Python: input rows are grouped so we send them in batches to Python.
// For each row, add it to the queue.
val inputIterator = iter.grouped(100).map { inputRows =>
- val toBePickled = inputRows.map { row =>
- queue.add(row)
- EvaluatePython.toJava(currentRow(row), schema)
+ val toBePickled = inputRows.map { inputRow =>
+ queue.add(inputRow)
+ val row = projection(inputRow)
+ if (needConversion) {
+ EvaluatePython.toJava(row, schema)
+ } else {
+ // fast path for types that do not need conversion in Python
+ val fields = new Array[Any](row.numFields)
+ var i = 0
+ while (i < row.numFields) {
+ val dt = dataTypes(i)
+ fields(i) = EvaluatePython.toJava(row.get(i, dt), dt)
+ i += 1
+ }
+ fields
+ }
}.toArray
pickle.dumps(toBePickled)
}
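The argOffsets computation above deduplicates semantically equal inputs across all UDFs and records, per UDF, which columns of the projected row hold its arguments. The same bookkeeping with strings standing in for expressions (semanticEquals reduced to plain equality here):

    object ArgOffsetsSketch extends App {
      import scala.collection.mutable.ArrayBuffer

      // Each inner Seq is one UDF's argument list.
      val udfInputs = Seq(Seq("a", "b"), Seq("b", "c"))

      val allInputs = ArrayBuffer.empty[String]
      val argOffsets = udfInputs.map { input =>
        input.map { e =>
          val idx = allInputs.indexOf(e)
          if (idx >= 0) idx                              // already projected: reuse its column
          else { allInputs += e; allInputs.length - 1 }  // new input: append and use the new column
        }.toArray
      }.toArray

      println(allInputs.mkString(", "))                         // a, b, c  (each projected once)
      println(argOffsets.map(_.mkString(",")).mkString(" | "))  // 0,1 | 1,2
    }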
@@ -75,22 +118,30 @@ case class BatchPythonEvaluation(udf: PythonUDF, output: Seq[Attribute], child:
val context = TaskContext.get()
// Output iterator for results from Python.
- val outputIterator = new PythonRunner(
- udf.func,
- bufferSize,
- reuseWorker
- ).compute(inputIterator, context.partitionId(), context)
+ val outputIterator = new PythonRunner(pyFuncs, bufferSize, reuseWorker, true, argOffsets)
+ .compute(inputIterator, context.partitionId(), context)
val unpickle = new Unpickler
- val row = new GenericMutableRow(1)
+ val mutableRow = new GenericMutableRow(1)
val joined = new JoinedRow
+ val resultType = if (udfs.length == 1) {
+ udfs.head.dataType
+ } else {
+ StructType(udfs.map(u => StructField("", u.dataType, u.nullable)))
+ }
val resultProj = UnsafeProjection.create(output, output)
outputIterator.flatMap { pickedResult =>
val unpickledBatch = unpickle.loads(pickedResult)
unpickledBatch.asInstanceOf[java.util.ArrayList[Any]].asScala
}.map { result =>
- row(0) = EvaluatePython.fromJava(result, udf.dataType)
+ val row = if (udfs.length == 1) {
+ // fast path for single UDF
+ mutableRow(0) = EvaluatePython.fromJava(result, resultType)
+ mutableRow
+ } else {
+ EvaluatePython.fromJava(result, resultType).asInstanceOf[InternalRow]
+ }
resultProj(joined(queue.poll(), row))
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala
index da28ec4f53..3b05e29e52 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala
@@ -35,26 +35,7 @@ import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, Generic
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
-/**
- * Evaluates a [[PythonUDF]], appending the result to the end of the input tuple.
- */
-case class EvaluatePython(
- udf: PythonUDF,
- child: LogicalPlan,
- resultAttribute: AttributeReference)
- extends logical.UnaryNode {
-
- def output: Seq[Attribute] = child.output :+ resultAttribute
-
- // References should not include the produced attribute.
- override def references: AttributeSet = udf.references
-}
-
-
object EvaluatePython {
- def apply(udf: PythonUDF, child: LogicalPlan): EvaluatePython =
- new EvaluatePython(udf, child, AttributeReference("pythonUDF", udf.dataType)())
-
def takeAndServe(df: DataFrame, n: Int): Int = {
registerPicklers()
df.withNewExecutionId {
@@ -66,6 +47,16 @@ object EvaluatePython {
}
}
+ def needConversionInPython(dt: DataType): Boolean = dt match {
+ case DateType | TimestampType => true
+ case _: StructType => true
+ case _: UserDefinedType[_] => true
+ case ArrayType(elementType, _) => needConversionInPython(elementType)
+ case MapType(keyType, valueType, _) =>
+ needConversionInPython(keyType) || needConversionInPython(valueType)
+ case _ => false
+ }
+
/**
* Helper for converting from Catalyst type to java type suitable for Pyrolite.
*/
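needConversionInPython recurses through nested types, so the pickling fast path is taken only when nothing anywhere in the schema needs Python-side conversion. The same recursion over a toy type ADT (illustrative, not Catalyst's DataType):

    object NeedsConversionSketch extends App {
      // A toy type tree mirroring the cases the predicate above distinguishes.
      sealed trait ToyType
      case object ToyInt extends ToyType
      case object ToyTimestamp extends ToyType
      final case class ToyArray(element: ToyType) extends ToyType
      final case class ToyMap(key: ToyType, value: ToyType) extends ToyType

      def needsConversion(t: ToyType): Boolean = t match {
        case ToyTimestamp => true                                      // converted on the Python side
        case ToyArray(e) => needsConversion(e)                         // recurse into the element type
        case ToyMap(k, v) => needsConversion(k) || needsConversion(v)  // recurse into key and value
        case _ => false                                                // primitives take the fast path
      }

      assert(!needsConversion(ToyArray(ToyInt)))
      assert(needsConversion(ToyMap(ToyInt, ToyArray(ToyTimestamp))))
      println("ok")
    }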
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
index 6e76e9569f..d72b3d347d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
@@ -17,63 +17,98 @@
package org.apache.spark.sql.execution.python
-import org.apache.spark.sql.catalyst.plans.logical
-import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import scala.collection.mutable
+
+import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression}
import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.execution
+import org.apache.spark.sql.execution.SparkPlan
/**
* Extracts PythonUDFs from operators, rewriting the query plan so that the UDF can be evaluated
* alone in a batch.
*
+ * Only extracts the PythonUDFs that can be evaluated in Python (the single child is a PythonUDF,
+ * or all the children can be evaluated in the JVM).
+ *
 * This has the limitation that the input to the Python UDF is not allowed to include attributes from
* multiple child operators.
*/
-private[spark] object ExtractPythonUDFs extends Rule[LogicalPlan] {
- def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
- // Skip EvaluatePython nodes.
- case plan: EvaluatePython => plan
+private[spark] object ExtractPythonUDFs extends Rule[SparkPlan] {
- case plan: LogicalPlan if plan.resolved =>
- // Extract any PythonUDFs from the current operator.
- val udfs = plan.expressions.flatMap(_.collect { case udf: PythonUDF => udf })
- if (udfs.isEmpty) {
- // If there aren't any, we are done.
- plan
- } else {
- // Pick the UDF we are going to evaluate (TODO: Support evaluating multiple UDFs at a time)
- // If there is more than one, we will add another evaluation operator in a subsequent pass.
- udfs.find(_.resolved) match {
- case Some(udf) =>
- var evaluation: EvaluatePython = null
+ private def hasPythonUDF(e: Expression): Boolean = {
+ e.find(_.isInstanceOf[PythonUDF]).isDefined
+ }
- // Rewrite the child that has the input required for the UDF
- val newChildren = plan.children.map { child =>
- // Check to make sure that the UDF can be evaluated with only the input of this child.
- // Other cases are disallowed as they are ambiguous or would require a cartesian
- // product.
- if (udf.references.subsetOf(child.outputSet)) {
- evaluation = EvaluatePython(udf, child)
- evaluation
- } else if (udf.references.intersect(child.outputSet).nonEmpty) {
- sys.error(s"Invalid PythonUDF $udf, requires attributes from more than one child.")
- } else {
- child
- }
- }
+ private def canEvaluateInPython(e: PythonUDF): Boolean = {
+ e.children match {
+ // single PythonUDF child could be chained and evaluated in Python
+ case Seq(u: PythonUDF) => canEvaluateInPython(u)
+ // Python UDF can't be evaluated directly in JVM
+ case children => !children.exists(hasPythonUDF)
+ }
+ }
- assert(evaluation != null, "Unable to evaluate PythonUDF. Missing input attributes.")
+ private def collectEvaluatableUDF(expr: Expression): Seq[PythonUDF] = expr match {
+ case udf: PythonUDF if canEvaluateInPython(udf) => Seq(udf)
+ case e => e.children.flatMap(collectEvaluatableUDF)
+ }
- // Trim away the new UDF value if it was only used for filtering or something.
- logical.Project(
- plan.output,
- plan.transformExpressions {
- case p: PythonUDF if p.fastEquals(udf) => evaluation.resultAttribute
- }.withNewChildren(newChildren))
+ def apply(plan: SparkPlan): SparkPlan = plan transformUp {
+ case plan: SparkPlan => extract(plan)
+ }
- case None =>
- // If there is no Python UDF that is resolved, skip this round.
- plan
+ /**
+ * Extract all the PythonUDFs from the current operator.
+ */
+ def extract(plan: SparkPlan): SparkPlan = {
+ val udfs = plan.expressions.flatMap(collectEvaluatableUDF)
+ if (udfs.isEmpty) {
+ // If there aren't any, we are done.
+ plan
+ } else {
+ val attributeMap = mutable.HashMap[PythonUDF, Expression]()
+ // Rewrite the child that has the input required for the UDF
+ val newChildren = plan.children.map { child =>
+ // Pick the UDF we are going to evaluate
+ val validUdfs = udfs.filter { case udf =>
+ // Check to make sure that the UDF can be evaluated with only the input of this child.
+ udf.references.subsetOf(child.outputSet)
}
+ if (validUdfs.nonEmpty) {
+ val resultAttrs = udfs.zipWithIndex.map { case (u, i) =>
+ AttributeReference(s"pythonUDF$i", u.dataType)()
+ }
+ val evaluation = BatchPythonEvaluation(validUdfs, child.output ++ resultAttrs, child)
+ attributeMap ++= validUdfs.zip(resultAttrs)
+ evaluation
+ } else {
+ child
+ }
+ }
+ // Other cases are disallowed as they are ambiguous or would require a cartesian
+ // product.
+ udfs.filterNot(attributeMap.contains).foreach { udf =>
+ if (udf.references.subsetOf(plan.inputSet)) {
+ sys.error(s"Invalid PythonUDF $udf, requires attributes from more than one child.")
+ } else {
+ sys.error(s"Unable to evaluate PythonUDF $udf. Missing input attributes.")
+ }
+ }
+
+ val rewritten = plan.transformExpressions {
+ case p: PythonUDF if attributeMap.contains(p) =>
+ attributeMap(p)
+ }.withNewChildren(newChildren)
+
+ // extract remaining python UDFs recursively
+ val newPlan = extract(rewritten)
+ if (newPlan.output != plan.output) {
+ // Trim away the new UDF value if it was only used for filtering or something.
+ execution.Project(plan.output, newPlan)
+ } else {
+ newPlan
}
+ }
}
}
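The rule only extracts a UDF whose children are either a single chained PythonUDF or entirely free of PythonUDFs; anything else would need JVM evaluation underneath the Python batch. That reachability check over a toy expression tree (illustrative, not Catalyst):

    object ExtractableUdfSketch extends App {
      // Toy expression nodes: either a Python UDF or an ordinary JVM expression.
      sealed trait Node { def children: Seq[Node] }
      final case class Udf(children: Seq[Node]) extends Node
      final case class Plain(children: Seq[Node] = Nil) extends Node

      def hasUdf(n: Node): Boolean = n.isInstanceOf[Udf] || n.children.exists(hasUdf)

      // Extractable iff the UDF chains a single UDF child, or no child contains a UDF at all.
      def canEvaluateInPython(u: Udf): Boolean = u.children match {
        case Seq(child: Udf) => canEvaluateInPython(child)
        case children => !children.exists(hasUdf)
      }

      assert(canEvaluateInPython(Udf(Seq(Plain()))))                // plain inputs: extractable
      assert(canEvaluateInPython(Udf(Seq(Udf(Seq(Plain()))))))      // chained UDFs: extractable
      assert(!canEvaluateInPython(Udf(Seq(Plain(Seq(Udf(Nil)))))))  // UDF under a JVM expression: not yet
      println("ok")
    }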
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala
index 4f1b837158..59d7e8dd6d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala
@@ -18,7 +18,6 @@
package org.apache.spark.sql.execution.python
import org.apache.spark.api.python.PythonFunction
-import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.expressions.{Expression, NonSQLExpression, Unevaluable}
import org.apache.spark.sql.types.DataType
@@ -30,7 +29,7 @@ case class PythonUDF(
func: PythonFunction,
dataType: DataType,
children: Seq[Expression])
- extends Expression with Unevaluable with NonSQLExpression with Logging {
+ extends Expression with Unevaluable with NonSQLExpression {
override def toString: String = s"$name(${children.mkString(", ")})"
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala
index e0b6709c51..d603f63a08 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala
@@ -296,7 +296,7 @@ private[sql] object StatFunctions extends Logging {
val defaultRelativeError: Double = 0.01
/**
- * Statisttics from the Greenwald-Khanna paper.
+ * Statistics from the Greenwald-Khanna paper.
* @param value the sampled value
* @param g the minimum rank jump from the previous value's minimum rank
* @param delta the maximum span of the rank.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala
index e819e95d61..6921ae584d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala
@@ -32,7 +32,7 @@ object FileStreamSink {
/**
* A sink that writes out results to parquet files. Each batch is written out to a unique
- * directory. After all of the files in a batch have been succesfully written, the list of
+ * directory. After all of the files in a batch have been successfully written, the list of
* file paths is appended to the log atomically. In the case of partial failures, some duplicate
* data may be present in the target directory, but only one copy of each file will be present
* in the log.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
new file mode 100644
index 0000000000..aaced49dd1
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
@@ -0,0 +1,72 @@
+/*
+* Licensed to the Apache Software Foundation (ASF) under one or more
+* contributor license agreements. See the NOTICE file distributed with
+* this work for additional information regarding copyright ownership.
+* The ASF licenses this file to You under the Apache License, Version 2.0
+* (the "License"); you may not use this file except in compliance with
+* the License. You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*/
+
+package org.apache.spark.sql.execution.streaming
+
+import org.apache.spark.sql.SQLContext
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.execution.{QueryExecution, SparkPlan, SparkPlanner, UnaryNode}
+
+/**
+ * A variant of [[QueryExecution]] that allows the given [[LogicalPlan]] to be executed
+ * incrementally, possibly preserving state in between each execution.
+ */
+class IncrementalExecution(
+ ctx: SQLContext,
+ logicalPlan: LogicalPlan,
+ checkpointLocation: String,
+ currentBatchId: Long) extends QueryExecution(ctx, logicalPlan) {
+
+ // TODO: make this always part of planning.
+ val stateStrategy = sqlContext.sessionState.planner.StatefulAggregationStrategy :: Nil
+
+ // Modified planner with stateful operations.
+ override def planner: SparkPlanner =
+ new SparkPlanner(
+ sqlContext.sparkContext,
+ sqlContext.conf,
+ stateStrategy)
+
+ /**
+ * Records the current id for a given stateful operator in the query plan as the `state`
+ * preparation walks the query plan.
+ */
+ private var operatorId = 0
+
+ /** Locates save/restore pairs surrounding aggregation. */
+ val state = new Rule[SparkPlan] {
+ override def apply(plan: SparkPlan): SparkPlan = plan transform {
+ case StateStoreSave(keys, None,
+ UnaryNode(agg,
+ StateStoreRestore(keys2, None, child))) =>
+ val stateId = OperatorStateId(checkpointLocation, operatorId, currentBatchId - 1)
+ operatorId += 1
+
+ StateStoreSave(
+ keys,
+ Some(stateId),
+ agg.withNewChildren(
+ StateStoreRestore(
+ keys,
+ Some(stateId),
+ child) :: Nil))
+ }
+ }
+
+ override def preparations: Seq[Rule[SparkPlan]] = state +: super.preparations
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulAggregate.scala
new file mode 100644
index 0000000000..595774761c
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulAggregate.scala
@@ -0,0 +1,119 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.errors._
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
+import org.apache.spark.sql.execution
+import org.apache.spark.sql.execution.streaming.state._
+import org.apache.spark.sql.execution.SparkPlan
+
+/** Used to identify the state store for a given operator. */
+case class OperatorStateId(
+ checkpointLocation: String,
+ operatorId: Long,
+ batchId: Long)
+
+/**
+ * An operator that saves or restores state from the [[StateStore]]. The [[OperatorStateId]] should
+ * be filled in by `prepareForExecution` in [[IncrementalExecution]].
+ */
+trait StatefulOperator extends SparkPlan {
+ def stateId: Option[OperatorStateId]
+
+ protected def getStateId: OperatorStateId = attachTree(this) {
+ stateId.getOrElse {
+ throw new IllegalStateException("State location not present for execution")
+ }
+ }
+}
+
+/**
+ * For each input tuple, the key is calculated and the value from the [[StateStore]] is added
+ * to the stream (in addition to the input tuple) if present.
+ */
+case class StateStoreRestore(
+ keyExpressions: Seq[Attribute],
+ stateId: Option[OperatorStateId],
+ child: SparkPlan) extends execution.UnaryNode with StatefulOperator {
+
+ override protected def doExecute(): RDD[InternalRow] = {
+ child.execute().mapPartitionsWithStateStore(
+ getStateId.checkpointLocation,
+ operatorId = getStateId.operatorId,
+ storeVersion = getStateId.batchId,
+ keyExpressions.toStructType,
+ child.output.toStructType,
+ new StateStoreConf(sqlContext.conf),
+ Some(sqlContext.streams.stateStoreCoordinator)) { case (store, iter) =>
+ val getKey = GenerateUnsafeProjection.generate(keyExpressions, child.output)
+ iter.flatMap { row =>
+ val key = getKey(row)
+ val savedState = store.get(key)
+ row +: savedState.toSeq
+ }
+ }
+ }
+ override def output: Seq[Attribute] = child.output
+}
+
+/**
+ * For each input tuple, the key is calculated and the tuple is `put` into the [[StateStore]].
+ */
+case class StateStoreSave(
+ keyExpressions: Seq[Attribute],
+ stateId: Option[OperatorStateId],
+ child: SparkPlan) extends execution.UnaryNode with StatefulOperator {
+
+ override protected def doExecute(): RDD[InternalRow] = {
+ child.execute().mapPartitionsWithStateStore(
+ getStateId.checkpointLocation,
+ operatorId = getStateId.operatorId,
+ storeVersion = getStateId.batchId,
+ keyExpressions.toStructType,
+ child.output.toStructType,
+ new StateStoreConf(sqlContext.conf),
+ Some(sqlContext.streams.stateStoreCoordinator)) { case (store, iter) =>
+ new Iterator[InternalRow] {
+ private[this] val baseIterator = iter
+ private[this] val getKey = GenerateUnsafeProjection.generate(keyExpressions, child.output)
+
+ override def hasNext: Boolean = {
+ if (!baseIterator.hasNext) {
+ store.commit()
+ false
+ } else {
+ true
+ }
+ }
+
+ override def next(): InternalRow = {
+ val row = baseIterator.next().asInstanceOf[UnsafeRow]
+ val key = getKey(row)
+ store.put(key.copy(), row.copy())
+ row
+ }
+ }
+ }
+ }
+
+ override def output: Seq[Attribute] = child.output
+}
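StateStoreSave wraps the child iterator so that every row is put into the store and commit() fires exactly when the input is exhausted. That wrap-and-commit-on-exhaustion shape in isolation (the "store" below is just a map with a commit flag, purely for illustration):

    object CommitOnExhaustionSketch extends App {
      import scala.collection.mutable

      // Stand-in for a state store: collects puts and records whether commit() ran.
      final class ToyStore {
        val data = mutable.Map.empty[String, Int]
        var committed = false
        def put(k: String, v: Int): Unit = data(k) = v
        def commit(): Unit = committed = true
      }

      // Wrap an iterator: write each element to the store, commit when the base iterator runs out.
      def savingIterator(base: Iterator[(String, Int)], store: ToyStore): Iterator[(String, Int)] =
        new Iterator[(String, Int)] {
          override def hasNext: Boolean =
            if (base.hasNext) true
            else { store.commit(); false }
          override def next(): (String, Int) = {
            val kv = base.next()
            store.put(kv._1, kv._2)
            kv
          }
        }

      val store = new ToyStore
      savingIterator(Iterator("a" -> 1, "b" -> 2), store).foreach(_ => ())  // drain, as a downstream operator would
      println(s"committed=${store.committed}, data=${store.data}")
    }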
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
index 60e00d203c..87dd27a2b1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
@@ -19,7 +19,6 @@ package org.apache.spark.sql.execution.streaming
import java.util.concurrent.{CountDownLatch, TimeUnit}
import java.util.concurrent.atomic.AtomicInteger
-import javax.annotation.concurrent.GuardedBy
import scala.collection.mutable.ArrayBuffer
import scala.util.control.NonFatal
@@ -28,12 +27,14 @@ import org.apache.hadoop.fs.Path
import org.apache.spark.internal.Logging
import org.apache.spark.sql._
+import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap}
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.util.ContinuousQueryListener
import org.apache.spark.sql.util.ContinuousQueryListener._
+import org.apache.spark.util.UninterruptibleThread
/**
* Manages the execution of a streaming Spark SQL query that is occurring in a separate thread.
@@ -42,30 +43,30 @@ import org.apache.spark.sql.util.ContinuousQueryListener._
* and the results are committed transactionally to the given [[Sink]].
*/
class StreamExecution(
- val sqlContext: SQLContext,
+ override val sqlContext: SQLContext,
override val name: String,
- val checkpointRoot: String,
+ checkpointRoot: String,
private[sql] val logicalPlan: LogicalPlan,
- val sink: Sink) extends ContinuousQuery with Logging {
+ val sink: Sink,
+ val trigger: Trigger) extends ContinuousQuery with Logging {
 /** A monitor used to wait/notify when batches complete. */
private val awaitBatchLock = new Object
private val startLatch = new CountDownLatch(1)
private val terminationLatch = new CountDownLatch(1)
- /** Minimum amount of time in between the start of each batch. */
- private val minBatchTime = 10
-
/**
* Tracks how much data we have processed and committed to the sink or state store from each
* input source.
*/
+ @volatile
private[sql] var committedOffsets = new StreamProgress
/**
 * Tracks the offsets that are available to be processed, but have not yet been committed to the
* sink.
*/
+ @volatile
private var availableOffsets = new StreamProgress
/** The current batchId or -1 if execution has not yet been initialized. */
@@ -73,11 +74,15 @@ class StreamExecution(
/** All stream sources present the query plan. */
private val sources =
- logicalPlan.collect { case s: StreamingRelation => s.source }
+ logicalPlan.collect { case s: StreamingExecutionRelation => s.source }
/** A list of unique sources in the query plan. */
private val uniqueSources = sources.distinct
+ private val triggerExecutor = trigger match {
+ case t: ProcessingTime => ProcessingTimeExecutor(t)
+ }
+
/** Defines the internal state of execution */
@volatile
private var state: State = INITIALIZED
@@ -89,9 +94,10 @@ class StreamExecution(
private[sql] var streamDeathCause: ContinuousQueryException = null
/** The thread that runs the micro-batches of this stream. */
- private[sql] val microBatchThread = new Thread(s"stream execution thread for $name") {
- override def run(): Unit = { runBatches() }
- }
+ private[sql] val microBatchThread =
+ new UninterruptibleThread(s"stream execution thread for $name") {
+ override def run(): Unit = { runBatches() }
+ }
/**
* A write-ahead-log that records the offsets that are present in each batch. In order to ensure
@@ -102,71 +108,13 @@ class StreamExecution(
private val offsetLog =
new HDFSMetadataLog[CompositeOffset](sqlContext, checkpointFile("offsets"))
- /** A monitor to protect "uninterruptible" and "interrupted" */
- private val uninterruptibleLock = new Object
-
- /**
- * Indicates if "microBatchThread" are in the uninterruptible status. If so, interrupting
- * "microBatchThread" will be deferred until "microBatchThread" enters into the interruptible
- * status.
- */
- @GuardedBy("uninterruptibleLock")
- private var uninterruptible = false
-
- /**
- * Indicates if we should interrupt "microBatchThread" when we are leaving the uninterruptible
- * zone.
- */
- @GuardedBy("uninterruptibleLock")
- private var shouldInterruptThread = false
-
- /**
- * Interrupt "microBatchThread" if possible. If "microBatchThread" is in the uninterruptible
- * status, "microBatchThread" won't be interrupted until it enters into the interruptible status.
- */
- private def interruptMicroBatchThreadSafely(): Unit = {
- uninterruptibleLock.synchronized {
- if (uninterruptible) {
- shouldInterruptThread = true
- } else {
- microBatchThread.interrupt()
- }
- }
- }
-
- /**
- * Run `f` uninterruptibly in "microBatchThread". "microBatchThread" won't be interrupted before
- * returning from `f`.
- */
- private def runUninterruptiblyInMicroBatchThread[T](f: => T): T = {
- assert(Thread.currentThread() == microBatchThread)
- uninterruptibleLock.synchronized {
- uninterruptible = true
- // Clear the interrupted status if it's set.
- if (Thread.interrupted()) {
- shouldInterruptThread = true
- }
- }
- try {
- f
- } finally {
- uninterruptibleLock.synchronized {
- uninterruptible = false
- if (shouldInterruptThread) {
- // Recover the interrupted status
- microBatchThread.interrupt()
- shouldInterruptThread = false
- }
- }
- }
- }
-
/** Whether the query is currently active or not */
override def isActive: Boolean = state == ACTIVE
/** Returns current status of all the sources. */
override def sourceStatuses: Array[SourceStatus] = {
- sources.map(s => new SourceStatus(s.toString, availableOffsets.get(s))).toArray
+ val localAvailableOffsets = availableOffsets
+ sources.map(s => new SourceStatus(s.toString, localAvailableOffsets.get(s))).toArray
}
/** Returns current status of the sink. */
@@ -211,11 +159,15 @@ class StreamExecution(
SQLContext.setActive(sqlContext)
populateStartOffsets()
logDebug(s"Stream running from $committedOffsets to $availableOffsets")
- while (isActive) {
- if (dataAvailable) runBatch()
- commitAndConstructNextBatch()
- Thread.sleep(minBatchTime) // TODO: Could be tighter
- }
+ triggerExecutor.execute(() => {
+ if (isActive) {
+ if (dataAvailable) runBatch()
+ constructNextBatch()
+ true
+ } else {
+ false
+ }
+ })
} catch {
case _: InterruptedException if state == TERMINATED => // interrupted by stop()
case NonFatal(e) =>
@@ -258,7 +210,7 @@ class StreamExecution(
case None => // We are starting this stream for the first time.
logInfo(s"Starting new continuous query.")
currentBatchId = 0
- commitAndConstructNextBatch()
+ constructNextBatch()
}
}
@@ -278,15 +230,8 @@ class StreamExecution(
/**
* Queries all of the sources to see if any new data is available. When there is new data the
* batchId counter is incremented and a new log entry is written with the newest offsets.
- *
- * Note that committing the offsets for a new batch implicitly marks the previous batch as
- * finished and thus this method should only be called when all currently available data
- * has been written to the sink.
*/
- private def commitAndConstructNextBatch(): Boolean = {
- // Update committed offsets.
- committedOffsets ++= availableOffsets
-
+ private def constructNextBatch(): Unit = {
// There is a potential dead-lock in Hadoop "Shell.runCommand" before 2.5.0 (HADOOP-10622).
// If we interrupt some thread running Shell.runCommand, we may hit this issue.
// As "FileStreamSource.getOffset" will create a file using HDFS API and call "Shell.runCommand"
@@ -294,33 +239,37 @@ class StreamExecution(
// method. See SPARK-14131.
//
// Check to see what new data is available.
- val newData = runUninterruptiblyInMicroBatchThread {
+ val newData = microBatchThread.runUninterruptibly {
uniqueSources.flatMap(s => s.getOffset.map(o => s -> o))
}
availableOffsets ++= newData
- if (dataAvailable) {
+ val hasNewData = awaitBatchLock.synchronized {
+ if (dataAvailable) {
+ true
+ } else {
+ noNewData = true
+ false
+ }
+ }
+ if (hasNewData) {
// There is a potential dead-lock in Hadoop "Shell.runCommand" before 2.5.0 (HADOOP-10622).
// If we interrupt some thread running Shell.runCommand, we may hit this issue.
// As "offsetLog.add" will create a file using HDFS API and call "Shell.runCommand" to set
// the file permission, we should not interrupt "microBatchThread" when running this method.
// See SPARK-14131.
- runUninterruptiblyInMicroBatchThread {
+ microBatchThread.runUninterruptibly {
assert(
offsetLog.add(currentBatchId, availableOffsets.toCompositeOffset(sources)),
s"Concurrent update to the log. Multiple streaming jobs detected for $currentBatchId")
}
currentBatchId += 1
logInfo(s"Committed offsets for batch $currentBatchId.")
- true
} else {
- noNewData = true
awaitBatchLock.synchronized {
// Wake up any threads that are waiting for the stream to progress.
awaitBatchLock.notifyAll()
}
-
- false
}
}
@@ -330,6 +279,8 @@ class StreamExecution(
private def runBatch(): Unit = {
val startTime = System.nanoTime()
+ // TODO: Move this to IncrementalExecution.
+
// Request unprocessed data from all sources.
val newData = availableOffsets.flatMap {
case (source, available) if committedOffsets.get(source).map(_ < available).getOrElse(true) =>
@@ -344,7 +295,7 @@ class StreamExecution(
var replacements = new ArrayBuffer[(Attribute, Attribute)]
// Replace sources in the logical plan with data that has arrived since the last batch.
val withNewSources = logicalPlan transform {
- case StreamingRelation(source, output) =>
+ case StreamingExecutionRelation(source, output) =>
newData.get(source).map { data =>
val newPlan = data.logicalPlan
assert(output.size == newPlan.output.size,
@@ -363,13 +314,14 @@ class StreamExecution(
}
val optimizerStart = System.nanoTime()
-
- lastExecution = new QueryExecution(sqlContext, newPlan)
- val executedPlan = lastExecution.executedPlan
+ lastExecution =
+ new IncrementalExecution(sqlContext, newPlan, checkpointFile("state"), currentBatchId)
+ lastExecution.executedPlan
val optimizerTime = (System.nanoTime() - optimizerStart).toDouble / 1000000
logDebug(s"Optimized batch in ${optimizerTime}ms")
- val nextBatch = Dataset.ofRows(sqlContext, newPlan)
+ val nextBatch =
+ new Dataset(sqlContext, lastExecution, RowEncoder(lastExecution.analyzed.schema))
sink.addBatch(currentBatchId - 1, nextBatch)
awaitBatchLock.synchronized {
@@ -379,6 +331,8 @@ class StreamExecution(
val batchTime = (System.nanoTime() - startTime).toDouble / 1000000
logInfo(s"Completed up to $availableOffsets in ${batchTime}ms")
+ // Update committed offsets.
+ committedOffsets ++= availableOffsets
postEvent(new QueryProgress(this))
}
@@ -395,7 +349,7 @@ class StreamExecution(
// intentionally
state = TERMINATED
if (microBatchThread.isAlive) {
- interruptMicroBatchThreadSafely()
+ microBatchThread.interrupt()
microBatchThread.join()
}
logInfo(s"Query $name was stopped")
@@ -406,7 +360,10 @@ class StreamExecution(
 * least the given `Offset`. This method is intended for use primarily when writing tests.
*/
def awaitOffset(source: Source, newOffset: Offset): Unit = {
- def notDone = !committedOffsets.contains(source) || committedOffsets(source) < newOffset
+ def notDone = {
+ val localCommittedOffsets = committedOffsets
+ !localCommittedOffsets.contains(source) || localCommittedOffsets(source) < newOffset
+ }
while (notDone) {
logInfo(s"Waiting until $newOffset at $source")
@@ -418,13 +375,17 @@ class StreamExecution(
/** A flag to indicate that a batch has completed with no new data available. */
@volatile private var noNewData = false
- override def processAllAvailable(): Unit = {
+ override def processAllAvailable(): Unit = awaitBatchLock.synchronized {
noNewData = false
- while (!noNewData) {
- awaitBatchLock.synchronized { awaitBatchLock.wait(10000) }
- if (streamDeathCause != null) { throw streamDeathCause }
+ while (true) {
+ awaitBatchLock.wait(10000)
+ if (streamDeathCause != null) {
+ throw streamDeathCause
+ }
+ if (noNewData) {
+ return
+ }
}
- if (streamDeathCause != null) { throw streamDeathCause }
}
override def awaitTermination(): Unit = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala
index e35c444348..d2872e49ce 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala
@@ -19,16 +19,37 @@ package org.apache.spark.sql.execution.streaming
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical.LeafNode
+import org.apache.spark.sql.execution.datasources.DataSource
object StreamingRelation {
- def apply(source: Source): StreamingRelation =
- StreamingRelation(source, source.schema.toAttributes)
+ def apply(dataSource: DataSource): StreamingRelation = {
+ val (name, schema) = dataSource.sourceSchema()
+ StreamingRelation(dataSource, name, schema.toAttributes)
+ }
+}
+
+/**
+ * Used to link a streaming [[DataSource]] into a
+ * [[org.apache.spark.sql.catalyst.plans.logical.LogicalPlan]]. This is only used for creating
+ * a streaming [[org.apache.spark.sql.DataFrame]] from [[org.apache.spark.sql.DataFrameReader]].
+ * It should be used to create a [[Source]] and converted to [[StreamingExecutionRelation]] when
+ * passed to [[StreamExecution]] to run a query.
+ */
+case class StreamingRelation(dataSource: DataSource, sourceName: String, output: Seq[Attribute])
+ extends LeafNode {
+ override def toString: String = sourceName
}
/**
* Used to link a streaming [[Source]] of data into a
* [[org.apache.spark.sql.catalyst.plans.logical.LogicalPlan]].
*/
-case class StreamingRelation(source: Source, output: Seq[Attribute]) extends LeafNode {
+case class StreamingExecutionRelation(source: Source, output: Seq[Attribute]) extends LeafNode {
override def toString: String = source.toString
}
+
+object StreamingExecutionRelation {
+ def apply(source: Source): StreamingExecutionRelation = {
+ StreamingExecutionRelation(source, source.schema.toAttributes)
+ }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TriggerExecutor.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TriggerExecutor.scala
new file mode 100644
index 0000000000..a1132d5106
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TriggerExecutor.scala
@@ -0,0 +1,72 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.ProcessingTime
+import org.apache.spark.util.{Clock, SystemClock}
+
+trait TriggerExecutor {
+
+ /**
+ * Execute batches using `batchRunner`. If `batchRunner` returns `false`, terminate the execution.
+ */
+ def execute(batchRunner: () => Boolean): Unit
+}
+
+/**
+ * A trigger executor that runs a batch every `intervalMs` milliseconds.
+ */
+case class ProcessingTimeExecutor(processingTime: ProcessingTime, clock: Clock = new SystemClock())
+ extends TriggerExecutor with Logging {
+
+ private val intervalMs = processingTime.intervalMs
+
+ override def execute(batchRunner: () => Boolean): Unit = {
+ while (true) {
+ val batchStartTimeMs = clock.getTimeMillis()
+ val terminated = !batchRunner()
+ if (intervalMs > 0) {
+ val batchEndTimeMs = clock.getTimeMillis()
+ val batchElapsedTimeMs = batchEndTimeMs - batchStartTimeMs
+ if (batchElapsedTimeMs > intervalMs) {
+ notifyBatchFallingBehind(batchElapsedTimeMs)
+ }
+ if (terminated) {
+ return
+ }
+ clock.waitTillTime(nextBatchTime(batchEndTimeMs))
+ } else {
+ if (terminated) {
+ return
+ }
+ }
+ }
+ }
+
+ /** Called when a batch falls behind. Exposed for testing only. */
+ def notifyBatchFallingBehind(realElapsedTimeMs: Long): Unit = {
+ logWarning("Current batch is falling behind. The trigger interval is " +
+ s"${intervalMs} milliseconds, but spent ${realElapsedTimeMs} milliseconds")
+ }
+
+ /** Return the next multiple of intervalMs */
+ def nextBatchTime(now: Long): Long = {
+ (now - 1) / intervalMs * intervalMs + intervalMs
+ }
+}
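nextBatchTime rounds up to the next multiple of the interval, and the `now - 1` keeps an exact boundary on itself instead of pushing it a full interval ahead. A few concrete values make the arithmetic easy to check (a 100 ms interval assumed for illustration):

    object NextBatchTimeSketch extends App {
      // Same formula as above, checked against a 100 ms interval.
      def nextBatchTime(now: Long, intervalMs: Long): Long =
        (now - 1) / intervalMs * intervalMs + intervalMs

      assert(nextBatchTime(250, 100) == 300)  // mid-interval: round up to the next multiple
      assert(nextBatchTime(300, 100) == 300)  // exactly on a boundary: stay on it
      assert(nextBatchTime(301, 100) == 400)  // just past a boundary: move to the next one
      println("ok")
    }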
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala
index 0f91e59e04..3820968324 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala
@@ -18,15 +18,16 @@
package org.apache.spark.sql.execution.streaming
import java.util.concurrent.atomic.AtomicInteger
+import javax.annotation.concurrent.GuardedBy
import scala.collection.mutable.ArrayBuffer
import scala.util.control.NonFatal
-import org.apache.spark.SparkEnv
import org.apache.spark.internal.Logging
import org.apache.spark.sql.{DataFrame, Dataset, Encoder, Row, SQLContext}
-import org.apache.spark.sql.catalyst.encoders.{encoderFor, RowEncoder}
-import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.encoders.encoderFor
+import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.catalyst.plans.logical.LeafNode
import org.apache.spark.sql.types.StructType
object MemoryStream {
@@ -45,10 +46,13 @@ object MemoryStream {
case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext)
extends Source with Logging {
protected val encoder = encoderFor[A]
- protected val logicalPlan = StreamingRelation(this)
+ protected val logicalPlan = StreamingExecutionRelation(this)
protected val output = logicalPlan.output
+
+ @GuardedBy("this")
protected val batches = new ArrayBuffer[Dataset[A]]
+ @GuardedBy("this")
protected var currentOffset: LongOffset = new LongOffset(-1)
def schema: StructType = encoder.schema
@@ -67,10 +71,10 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext)
def addData(data: TraversableOnce[A]): Offset = {
import sqlContext.implicits._
+ val ds = data.toVector.toDS()
+ logDebug(s"Adding ds: $ds")
this.synchronized {
currentOffset = currentOffset + 1
- val ds = data.toVector.toDS()
- logDebug(s"Adding ds: $ds")
batches.append(ds)
currentOffset
}
@@ -78,10 +82,12 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext)
override def toString: String = s"MemoryStream[${output.mkString(",")}]"
- override def getOffset: Option[Offset] = if (batches.isEmpty) {
- None
- } else {
- Some(currentOffset)
+ override def getOffset: Option[Offset] = synchronized {
+ if (batches.isEmpty) {
+ None
+ } else {
+ Some(currentOffset)
+ }
}
/**
@@ -91,7 +97,7 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext)
val startOrdinal =
start.map(_.asInstanceOf[LongOffset]).getOrElse(LongOffset(-1)).offset.toInt + 1
val endOrdinal = end.asInstanceOf[LongOffset].offset.toInt + 1
- val newBlocks = batches.slice(startOrdinal, endOrdinal)
+ val newBlocks = synchronized { batches.slice(startOrdinal, endOrdinal) }
logDebug(
s"MemoryBatch [$startOrdinal, $endOrdinal]: ${newBlocks.flatMap(_.collect()).mkString(", ")}")
@@ -108,8 +114,9 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext)
* A sink that stores the results in memory. This [[Sink]] is primarily intended for use in unit
* tests and does not provide durability.
*/
-class MemorySink(schema: StructType) extends Sink with Logging {
+class MemorySink(val schema: StructType) extends Sink with Logging {
/** An order list of batches that have been written to this [[Sink]]. */
+ @GuardedBy("this")
private val batches = new ArrayBuffer[Array[Row]]()
/** Returns all rows that are stored in this [[Sink]]. */
@@ -117,6 +124,8 @@ class MemorySink(schema: StructType) extends Sink with Logging {
batches.flatten
}
+ def lastBatch: Seq[Row] = synchronized { batches.last }
+
def toDebugString: String = synchronized {
batches.zipWithIndex.map { case (b, i) =>
val dataStr = try b.mkString(" ") catch {
@@ -126,7 +135,7 @@ class MemorySink(schema: StructType) extends Sink with Logging {
}.mkString("\n")
}
- override def addBatch(batchId: Long, data: DataFrame): Unit = {
+ override def addBatch(batchId: Long, data: DataFrame): Unit = synchronized {
if (batchId == batches.size) {
logDebug(s"Committing batch $batchId")
batches.append(data.collect())
@@ -136,3 +145,9 @@ class MemorySink(schema: StructType) extends Sink with Logging {
}
}
+/**
+ * Used to query the data that has been written into a [[MemorySink]].
+ */
+case class MemoryPlan(sink: MemorySink, output: Seq[Attribute]) extends LeafNode {
+ def this(sink: MemorySink) = this(sink, sink.schema.toAttributes)
+}
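
A hedged, test-style sketch of the now-synchronized MemoryStream; the surrounding `sqlContext` parameter and the exact `LongOffset` value after one `addData` call are assumptions based on the code in this file:

{{{
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.execution.streaming.{LongOffset, MemoryStream}

def memoryStreamSketch(sqlContext: SQLContext): Unit = {
  import sqlContext.implicits._                    // provides the implicit Encoder[Int]
  val stream = MemoryStream[Int](id = 0, sqlContext)
  assert(stream.getOffset.isEmpty)                 // no batches yet, so no offset
  stream.addData(Seq(1, 2, 3))                     // batches/currentOffset are @GuardedBy("this")
  assert(stream.getOffset.contains(LongOffset(0))) // offset advances from -1 to 0
}
}}}
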
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala
index ee015baf3f..3335755fd3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala
@@ -46,12 +46,14 @@ import org.apache.spark.util.Utils
* Usage:
* To update the data in the state store, the following order of operations are needed.
*
- * - val store = StateStore.get(operatorId, partitionId, version) // to get the right store
- * - store.update(...)
+ * // get the right store
+ * - val store = StateStore.get(
+ * StateStoreId(checkpointLocation, operatorId, partitionId), ..., version, ...)
+ * - store.put(...)
* - store.remove(...)
- * - store.commit() // commits all the updates to made with version number
+ * - store.commit() // commits all the updates made; the new version will be returned
* - store.iterator() // key-value data after last commit as an iterator
- * - store.updates() // updates made in the last as an iterator
+ * - store.updates() // updates made in the last commit as an iterator
*
* Fault-tolerance model:
* - Every set of updates is written to a delta file before committing.
@@ -81,7 +83,7 @@ private[state] class HDFSBackedStateStoreProvider(
trait STATE
case object UPDATING extends STATE
case object COMMITTED extends STATE
- case object CANCELLED extends STATE
+ case object ABORTED extends STATE
private val newVersion = version + 1
private val tempDeltaFile = new Path(baseDir, s"temp-${Random.nextLong}")
@@ -94,15 +96,14 @@ private[state] class HDFSBackedStateStoreProvider(
override def id: StateStoreId = HDFSBackedStateStoreProvider.this.id
- /**
- * Update the value of a key using the value generated by the update function.
- * @note Do not mutate the retrieved value row as it will unexpectedly affect the previous
- * versions of the store data.
- */
- override def update(key: UnsafeRow, updateFunc: Option[UnsafeRow] => UnsafeRow): Unit = {
- verify(state == UPDATING, "Cannot update after already committed or cancelled")
- val oldValueOption = Option(mapToUpdate.get(key))
- val value = updateFunc(oldValueOption)
+ override def get(key: UnsafeRow): Option[UnsafeRow] = {
+ Option(mapToUpdate.get(key))
+ }
+
+ override def put(key: UnsafeRow, value: UnsafeRow): Unit = {
+ verify(state == UPDATING, "Cannot remove after already committed or aborted")
+
+ val isNewKey = !mapToUpdate.containsKey(key)
mapToUpdate.put(key, value)
Option(allUpdates.get(key)) match {
@@ -110,13 +111,12 @@ private[state] class HDFSBackedStateStoreProvider(
// Value did not exist in previous version and was added already, keep it marked as added
allUpdates.put(key, ValueAdded(key, value))
case Some(ValueUpdated(_, _)) | Some(KeyRemoved(_)) =>
- // Value existed in prev version and updated/removed, mark it as updated
+ // Value existed in previous version and updated/removed, mark it as updated
allUpdates.put(key, ValueUpdated(key, value))
case None =>
// There was no prior update, so mark this as added or updated according to its presence
// in previous version.
- val update =
- if (oldValueOption.nonEmpty) ValueUpdated(key, value) else ValueAdded(key, value)
+ val update = if (isNewKey) ValueAdded(key, value) else ValueUpdated(key, value)
allUpdates.put(key, update)
}
writeToDeltaFile(tempDeltaFileStream, ValueUpdated(key, value))
@@ -124,7 +124,7 @@ private[state] class HDFSBackedStateStoreProvider(
/** Remove keys that match the following condition */
override def remove(condition: UnsafeRow => Boolean): Unit = {
- verify(state == UPDATING, "Cannot remove after already committed or cancelled")
+ verify(state == UPDATING, "Cannot remove after already committed or aborted")
val keyIter = mapToUpdate.keySet().iterator()
while (keyIter.hasNext) {
val key = keyIter.next
@@ -148,7 +148,7 @@ private[state] class HDFSBackedStateStoreProvider(
/** Commit all the updates that have been made to the store, and return the new version. */
override def commit(): Long = {
- verify(state == UPDATING, "Cannot commit again after already committed or cancelled")
+ verify(state == UPDATING, "Cannot commit after already committed or aborted")
try {
finalizeDeltaFile(tempDeltaFileStream)
@@ -163,40 +163,44 @@ private[state] class HDFSBackedStateStoreProvider(
}
}
- /** Cancel all the updates made on this store. This store will not be usable any more. */
- override def cancel(): Unit = {
- state = CANCELLED
+ /** Abort all the updates made on this store. This store will not be usable any more. */
+ override def abort(): Unit = {
+ verify(state == UPDATING || state == ABORTED, "Cannot abort after already committed")
+
+ state = ABORTED
if (tempDeltaFileStream != null) {
tempDeltaFileStream.close()
}
if (tempDeltaFile != null && fs.exists(tempDeltaFile)) {
fs.delete(tempDeltaFile, true)
}
- logInfo("Canceled ")
+ logInfo("Aborted")
}
/**
- * Get an iterator of all the store data. This can be called only after committing the
- * updates.
+ * Get an iterator of all the store data.
+ * This can be called only after committing all the updates made in the current thread.
*/
override def iterator(): Iterator[(UnsafeRow, UnsafeRow)] = {
- verify(state == COMMITTED, "Cannot get iterator of store data before comitting")
+ verify(state == COMMITTED,
+ "Cannot get iterator of store data before committing or after aborting")
HDFSBackedStateStoreProvider.this.iterator(newVersion)
}
/**
* Get an iterator of all the updates made to the store in the current version.
- * This can be called only after committing the updates.
+ * This can be called only after committing all the updates made in the current thread.
*/
override def updates(): Iterator[StoreUpdate] = {
- verify(state == COMMITTED, "Cannot get iterator of updates before committing")
+ verify(state == COMMITTED,
+ "Cannot get iterator of updates before committing or after aborting")
allUpdates.values().asScala.toIterator
}
/**
* Whether all updates have been committed
*/
- override def hasCommitted: Boolean = {
+ override private[state] def hasCommitted: Boolean = {
state == COMMITTED
}
}
@@ -225,7 +229,7 @@ private[state] class HDFSBackedStateStoreProvider(
}
override def toString(): String = {
- s"StateStore[id = (op=${id.operatorId},part=${id.partitionId}), dir = $baseDir]"
+ s"StateStore[id = (op=${id.operatorId}, part=${id.partitionId}), dir = $baseDir]"
}
/* Internal classes and methods */
@@ -279,7 +283,7 @@ private[state] class HDFSBackedStateStoreProvider(
} else {
if (!fs.isDirectory(baseDir)) {
throw new IllegalStateException(
- s"Cannot use ${id.checkpointLocation} for storing state data for $this as" +
+ s"Cannot use ${id.checkpointLocation} for storing state data for $this as " +
s"$baseDir already exists and is not a directory")
}
}
@@ -455,11 +459,11 @@ private[state] class HDFSBackedStateStoreProvider(
filesForVersion(files, lastVersion).filter(_.isSnapshot == false)
synchronized { loadedMaps.get(lastVersion) } match {
case Some(map) =>
- if (deltaFilesForLastVersion.size > storeConf.maxDeltasForSnapshot) {
+ if (deltaFilesForLastVersion.size > storeConf.minDeltasForSnapshot) {
writeSnapshotFile(lastVersion, map)
}
case None =>
- // The last map is not loaded, probably some other instance is incharge
+ // The last map is not loaded, probably some other instance is in charge
}
}
@@ -470,10 +474,10 @@ private[state] class HDFSBackedStateStoreProvider(
}
/**
- * Clean up old snapshots and delta files that are not needed any more. It ensures that last
- * few versions of the store can be recovered from the files, so re-executed RDD operations
- * can re-apply updates on the past versions of the store.
- */
+ * Clean up old snapshots and delta files that are not needed any more. It ensures that the last
+ * few versions of the store can be recovered from the files, so re-executed RDD operations
+ * can re-apply updates on the past versions of the store.
+ */
private[state] def cleanup(): Unit = {
try {
val files = fetchFiles()
@@ -508,7 +512,6 @@ private[state] class HDFSBackedStateStoreProvider(
.lastOption
val deltaBatchFiles = latestSnapshotFileBeforeVersion match {
case Some(snapshotFile) =>
- val deltaBatchIds = (snapshotFile.version + 1) to version
val deltaFiles = allFiles.filter { file =>
file.version > snapshotFile.version && file.version <= version
@@ -581,4 +584,3 @@ private[state] class HDFSBackedStateStoreProvider(
}
}
}
-
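
A short hedged sketch of the get/put/remove/commit flow described in the updated class comment; the key/value rows and the partition-level driver function are placeholders, not code from this patch:

{{{
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.sql.execution.streaming.state.StateStore

// Apply a stream of upserts/deletes to one partition's store, then commit.
def updatePartition(
    store: StateStore,
    updates: Iterator[(UnsafeRow, Option[UnsafeRow])]): Long = {
  updates.foreach {
    case (key, Some(value)) => store.put(key, value)   // upsert
    case (key, None)        => store.remove(_ == key)  // remove keys matching the condition
  }
  store.commit()  // returns the new version
}
}}}
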
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
index ca5c864d9e..9521506325 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
@@ -17,7 +17,6 @@
package org.apache.spark.sql.execution.streaming.state
-import java.util.Timer
import java.util.concurrent.{ScheduledFuture, TimeUnit}
import scala.collection.mutable
@@ -47,12 +46,11 @@ trait StateStore {
/** Version of the data in this store before committing updates. */
def version: Long
- /**
- * Update the value of a key using the value generated by the update function.
- * @note Do not mutate the retrieved value row as it will unexpectedly affect the previous
- * versions of the store data.
- */
- def update(key: UnsafeRow, updateFunc: Option[UnsafeRow] => UnsafeRow): Unit
+ /** Get the current value of a key. */
+ def get(key: UnsafeRow): Option[UnsafeRow]
+
+ /** Put a new value for a key. */
+ def put(key: UnsafeRow, value: UnsafeRow): Unit
/**
* Remove keys that match the following condition.
@@ -64,25 +62,25 @@ trait StateStore {
*/
def commit(): Long
- /** Cancel all the updates that have been made to the store. */
- def cancel(): Unit
+ /** Abort all the updates that have been made to the store. */
+ def abort(): Unit
/**
* Iterator of store data after a set of updates have been committed.
- * This can be called only after commitUpdates() has been called in the current thread.
+ * This can be called only after committing all the updates made in the current thread.
*/
def iterator(): Iterator[(UnsafeRow, UnsafeRow)]
/**
* Iterator of the updates that have been committed.
- * This can be called only after commitUpdates() has been called in the current thread.
+ * This can be called only after committing all the updates made in the current thread.
*/
def updates(): Iterator[StoreUpdate]
/**
* Whether all updates have been committed
*/
- def hasCommitted: Boolean
+ private[state] def hasCommitted: Boolean
}
@@ -110,8 +108,8 @@ case class KeyRemoved(key: UnsafeRow) extends StoreUpdate
/**
* Companion object to [[StateStore]] that provides helper methods to create and retrieve stores
* by their unique ids. In addition, when a SparkContext is active (i.e. SparkEnv.get is not null),
- * it also runs a periodic background tasks to do maintenance on the loaded stores. For each
- * store, tt uses the [[StateStoreCoordinator]] to ensure whether the current loaded instance of
+ * it also runs a periodic background task to do maintenance on the loaded stores. For each
+ * store, it uses the [[StateStoreCoordinator]] to verify whether the currently loaded instance of
* the store is the active instance. Accordingly, it either keeps it loaded and performs
* maintenance, or unloads the store.
*/
@@ -221,7 +219,7 @@ private[state] object StateStore extends Logging {
val executorId = SparkEnv.get.blockManager.blockManagerId.executorId
val verified =
coordinatorRef.map(_.verifyIfInstanceActive(storeId, executorId)).getOrElse(false)
- logDebug(s"Verifyied whether the loaded instance $storeId is active: $verified" )
+ logDebug(s"Verified whether the loaded instance $storeId is active: $verified" )
verified
} catch {
case NonFatal(e) =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala
index cca22a0af8..e55f63a6c8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala
@@ -20,18 +20,17 @@ package org.apache.spark.sql.execution.streaming.state
import org.apache.spark.sql.internal.SQLConf
/** A class that contains configuration parameters for [[StateStore]]s. */
-private[state] class StateStoreConf(@transient private val conf: SQLConf) extends Serializable {
+private[streaming] class StateStoreConf(@transient private val conf: SQLConf) extends Serializable {
def this() = this(new SQLConf)
import SQLConf._
- val maxDeltasForSnapshot = conf.getConf(STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT)
+ val minDeltasForSnapshot = conf.getConf(STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT)
val minVersionsToRetain = conf.getConf(STATE_STORE_MIN_VERSIONS_TO_RETAIN)
}
-private[state] object StateStoreConf {
+private[streaming] object StateStoreConf {
val empty = new StateStoreConf()
}
-
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala
index 5aa0636850..e418217238 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala
@@ -21,7 +21,7 @@ import scala.collection.mutable
import org.apache.spark.SparkEnv
import org.apache.spark.internal.Logging
-import org.apache.spark.rpc.{RpcCallContext, RpcEndpoint, RpcEndpointRef, RpcEnv}
+import org.apache.spark.rpc.{RpcCallContext, RpcEndpointRef, RpcEnv, ThreadSafeRpcEndpoint}
import org.apache.spark.scheduler.ExecutorCacheTaskLocation
import org.apache.spark.util.RpcUtils
@@ -50,8 +50,7 @@ private[sql] object StateStoreCoordinatorRef extends Logging {
private val endpointName = "StateStoreCoordinator"
/**
- * Create a reference to a [[StateStoreCoordinator]], This can be called from driver as well as
- * executors.
+ * Create a reference to a [[StateStoreCoordinator]].
*/
def forDriver(env: SparkEnv): StateStoreCoordinatorRef = synchronized {
try {
@@ -75,7 +74,7 @@ private[sql] object StateStoreCoordinatorRef extends Logging {
}
/**
- * Reference to a [[StateStoreCoordinator]] that can be used to coordinator instances of
+ * Reference to a [[StateStoreCoordinator]] that can be used to coordinate instances of
* [[StateStore]]s across all the executors, and get their locations for job scheduling.
*/
private[sql] class StateStoreCoordinatorRef private(rpcEndpointRef: RpcEndpointRef) {
@@ -112,7 +111,7 @@ private[sql] class StateStoreCoordinatorRef private(rpcEndpointRef: RpcEndpointR
* Class for coordinating instances of [[StateStore]]s loaded in executors across the cluster,
* and get their locations for job scheduling.
*/
-private class StateStoreCoordinator(override val rpcEnv: RpcEnv) extends RpcEndpoint {
+private class StateStoreCoordinator(override val rpcEnv: RpcEnv) extends ThreadSafeRpcEndpoint {
private val instances = new mutable.HashMap[StateStoreId, ExecutorCacheTaskLocation]
override def receive: PartialFunction[Any, Unit] = {
@@ -142,5 +141,3 @@ private class StateStoreCoordinator(override val rpcEnv: RpcEnv) extends RpcEndp
context.reply(true)
}
}
-
-
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala
index 3318660895..d708486d8e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala
@@ -22,12 +22,12 @@ import scala.reflect.ClassTag
import org.apache.spark.{Partition, TaskContext}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.types.StructType
-import org.apache.spark.util.{SerializableConfiguration, Utils}
+import org.apache.spark.util.SerializableConfiguration
/**
* An RDD that allows computations to be executed against [[StateStore]]s. It
- * uses the [[StateStoreCoordinator]] to use the locations of loaded state stores as
- * preferred locations.
+ * uses the [[StateStoreCoordinator]] to get the locations of loaded state stores
+ * and uses them as the preferred locations.
*/
class StateStoreRDD[T: ClassTag, U: ClassTag](
dataRDD: RDD[T],
@@ -54,17 +54,10 @@ class StateStoreRDD[T: ClassTag, U: ClassTag](
override def compute(partition: Partition, ctxt: TaskContext): Iterator[U] = {
var store: StateStore = null
-
- Utils.tryWithSafeFinally {
- val storeId = StateStoreId(checkpointLocation, operatorId, partition.index)
- store = StateStore.get(
- storeId, keySchema, valueSchema, storeVersion, storeConf, confBroadcast.value.value)
- val inputIter = dataRDD.iterator(partition, ctxt)
- val outputIter = storeUpdateFunction(store, inputIter)
- assert(store.hasCommitted)
- outputIter
- } {
- if (store != null) store.cancel()
- }
+ val storeId = StateStoreId(checkpointLocation, operatorId, partition.index)
+ store = StateStore.get(
+ storeId, keySchema, valueSchema, storeVersion, storeConf, confBroadcast.value.value)
+ val inputIter = dataRDD.iterator(partition, ctxt)
+ storeUpdateFunction(store, inputIter)
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala
index b249e37921..9b6d0918e2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala
@@ -28,37 +28,36 @@ package object state {
implicit class StateStoreOps[T: ClassTag](dataRDD: RDD[T]) {
/** Map each partition of a RDD along with data in a [[StateStore]]. */
- def mapPartitionWithStateStore[U: ClassTag](
- storeUpdateFunction: (StateStore, Iterator[T]) => Iterator[U],
+ def mapPartitionsWithStateStore[U: ClassTag](
+ sqlContext: SQLContext,
checkpointLocation: String,
operatorId: Long,
storeVersion: Long,
keySchema: StructType,
- valueSchema: StructType
- )(implicit sqlContext: SQLContext): StateStoreRDD[T, U] = {
+ valueSchema: StructType)(
+ storeUpdateFunction: (StateStore, Iterator[T]) => Iterator[U]): StateStoreRDD[T, U] = {
- mapPartitionWithStateStore(
- storeUpdateFunction,
+ mapPartitionsWithStateStore(
checkpointLocation,
operatorId,
storeVersion,
keySchema,
valueSchema,
new StateStoreConf(sqlContext.conf),
- Some(sqlContext.streams.stateStoreCoordinator))
+ Some(sqlContext.streams.stateStoreCoordinator))(
+ storeUpdateFunction)
}
/** Map each partition of a RDD along with data in a [[StateStore]]. */
- private[state] def mapPartitionWithStateStore[U: ClassTag](
- storeUpdateFunction: (StateStore, Iterator[T]) => Iterator[U],
+ private[streaming] def mapPartitionsWithStateStore[U: ClassTag](
checkpointLocation: String,
operatorId: Long,
storeVersion: Long,
keySchema: StructType,
valueSchema: StructType,
storeConf: StateStoreConf,
- storeCoordinator: Option[StateStoreCoordinatorRef]
- ): StateStoreRDD[T, U] = {
+ storeCoordinator: Option[StateStoreCoordinatorRef])(
+ storeUpdateFunction: (StateStore, Iterator[T]) => Iterator[U]): StateStoreRDD[T, U] = {
val cleanedF = dataRDD.sparkContext.clean(storeUpdateFunction)
new StateStoreRDD(
dataRDD,
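
A hedged sketch of calling the new curried `mapPartitionsWithStateStore` signature; the schemas, checkpoint directory, and the trivial update function are illustrative assumptions, not code from this patch:

{{{
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.sql.execution.streaming.state._
import org.apache.spark.sql.types.StructType

def saveRows(
    sqlContext: SQLContext,
    rows: RDD[UnsafeRow],
    keySchema: StructType,
    valueSchema: StructType,
    checkpointDir: String): RDD[UnsafeRow] = {
  rows.mapPartitionsWithStateStore(
      sqlContext, checkpointDir, 0L /* operatorId */, 0L /* storeVersion */,
      keySchema, valueSchema) { (store, iter) =>
    val seen = iter.map { row => store.put(row, row); row }.toVector // materialize before commit
    store.commit()
    seen.iterator
  }
}
}}}
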
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala
index 0d580703f5..4b3091ba22 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala
@@ -17,12 +17,12 @@
package org.apache.spark.sql.execution
+import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.catalyst.{expressions, InternalRow}
import org.apache.spark.sql.catalyst.expressions.{ExprId, Literal, SubqueryExpression}
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
-import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, ReturnAnswer}
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule
-import org.apache.spark.sql.internal.SessionState
import org.apache.spark.sql.types.DataType
/**
@@ -60,14 +60,13 @@ case class ScalarSubquery(
}
/**
- * Convert the subquery from logical plan into executed plan.
+ * Plans the scalar subqueries that are present in the given [[SparkPlan]].
*/
-case class PlanSubqueries(sessionState: SessionState) extends Rule[SparkPlan] {
+case class PlanSubqueries(sqlContext: SQLContext) extends Rule[SparkPlan] {
def apply(plan: SparkPlan): SparkPlan = {
plan.transformAllExpressions {
case subquery: expressions.ScalarSubquery =>
- val sparkPlan = sessionState.planner.plan(ReturnAnswer(subquery.query)).next()
- val executedPlan = sessionState.prepareForExecution.execute(sparkPlan)
+ val executedPlan = new QueryExecution(sqlContext, subquery.plan).executedPlan
ScalarSubquery(executedPlan, subquery.exprId)
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/AllExecutionsPage.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/AllExecutionsPage.scala
index d3e823fdeb..e96fb9f755 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/AllExecutionsPage.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/AllExecutionsPage.scala
@@ -55,6 +55,12 @@ private[ui] class AllExecutionsPage(parent: SQLTab) extends WebUIPage("") with L
}
_content
}
+ content ++=
+ <script>
+ function clickDetail(details) {{
+ details.parentNode.querySelector('.stage-details').classList.toggle('collapsed')
+ }}
+ </script>
UIUtils.headerSparkPage("SQL", content, parent, Some(5000))
}
}
@@ -118,14 +124,12 @@ private[ui] abstract class ExecutionTable(
{failedJobs}
</td>
}}
- {detailCell(executionUIData.physicalPlanDescription)}
</tr>
}
private def descriptionCell(execution: SQLExecutionUIData): Seq[Node] = {
val details = if (execution.details.nonEmpty) {
- <span onclick="this.parentNode.querySelector('.stage-details').classList.toggle('collapsed')"
- class="expand-details">
+ <span onclick="clickDetail(this)" class="expand-details">
+details
</span> ++
<div class="stage-details collapsed">
@@ -142,30 +146,6 @@ private[ui] abstract class ExecutionTable(
<div>{desc} {details}</div>
}
- private def detailCell(physicalPlan: String): Seq[Node] = {
- val isMultiline = physicalPlan.indexOf('\n') >= 0
- val summary = StringEscapeUtils.escapeHtml4(
- if (isMultiline) {
- physicalPlan.substring(0, physicalPlan.indexOf('\n'))
- } else {
- physicalPlan
- })
- val details = if (isMultiline) {
- // scalastyle:off
- <span onclick="this.parentNode.querySelector('.stacktrace-details').classList.toggle('collapsed')"
- class="expand-details">
- +details
- </span> ++
- <div class="stacktrace-details collapsed">
- <pre>{physicalPlan}</pre>
- </div>
- // scalastyle:on
- } else {
- ""
- }
- <td>{summary}{details}</td>
- }
-
def toNodeSeq: Seq[Node] = {
<div>
<h4>{tableName}</h4>
@@ -197,7 +177,7 @@ private[ui] class RunningExecutionTable(
showFailedJobs = true) {
override protected def header: Seq[String] =
- baseHeader ++ Seq("Running Jobs", "Succeeded Jobs", "Failed Jobs", "Detail")
+ baseHeader ++ Seq("Running Jobs", "Succeeded Jobs", "Failed Jobs")
}
private[ui] class CompletedExecutionTable(
@@ -215,7 +195,7 @@ private[ui] class CompletedExecutionTable(
showSucceededJobs = true,
showFailedJobs = false) {
- override protected def header: Seq[String] = baseHeader ++ Seq("Jobs", "Detail")
+ override protected def header: Seq[String] = baseHeader ++ Seq("Jobs")
}
private[ui] class FailedExecutionTable(
@@ -234,5 +214,5 @@ private[ui] class FailedExecutionTable(
showFailedJobs = true) {
override protected def header: Seq[String] =
- baseHeader ++ Seq("Succeeded Jobs", "Failed Jobs", "Detail")
+ baseHeader ++ Seq("Succeeded Jobs", "Failed Jobs")
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala
index 24a01f5be1..c6fcb6956c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala
@@ -45,8 +45,8 @@ private[ui] case class SparkPlanGraph(
}
/**
- * All the SparkPlanGraphNodes, including those inside of WholeStageCodegen.
- */
+ * All the SparkPlanGraphNodes, including those inside of WholeStageCodegen.
+ */
val allNodes: Seq[SparkPlanGraphNode] = {
nodes.flatMap {
case cluster: SparkPlanGraphCluster => cluster.nodes :+ cluster
@@ -167,8 +167,8 @@ private[ui] class SparkPlanGraphNode(
}
/**
- * Represent a tree of SparkPlan for WholeStageCodegen.
- */
+ * Represents a tree of SparkPlan for WholeStageCodegen.
+ */
private[ui] class SparkPlanGraphCluster(
id: Long,
name: String,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala
index 844f3051fa..7da8379c9a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala
@@ -43,52 +43,65 @@ import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression
*
* Based loosely on Aggregator from Algebird: https://github.com/twitter/algebird
*
- * @tparam I The input type for the aggregation.
- * @tparam B The type of the intermediate value of the reduction.
- * @tparam O The type of the final output result.
+ * @tparam IN The input type for the aggregation.
+ * @tparam BUF The type of the intermediate value of the reduction.
+ * @tparam OUT The type of the final output result.
* @since 1.6.0
*/
-abstract class Aggregator[-I, B, O] extends Serializable {
+abstract class Aggregator[-IN, BUF, OUT] extends Serializable {
/**
* A zero value for this aggregation. Should satisfy the property that any b + zero = b.
* @since 1.6.0
*/
- def zero: B
+ def zero: BUF
/**
* Combine two values to produce a new value. For performance, the function may modify `b` and
* return it instead of constructing new object for b.
* @since 1.6.0
*/
- def reduce(b: B, a: I): B
+ def reduce(b: BUF, a: IN): BUF
/**
* Merge two intermediate values.
* @since 1.6.0
*/
- def merge(b1: B, b2: B): B
+ def merge(b1: BUF, b2: BUF): BUF
/**
* Transform the output of the reduction.
* @since 1.6.0
*/
- def finish(reduction: B): O
+ def finish(reduction: BUF): OUT
/**
- * Returns this `Aggregator` as a [[TypedColumn]] that can be used in [[Dataset]] or [[DataFrame]]
+ * Specifies the [[Encoder]] for the intermediate value type.
+ * @since 2.0.0
+ */
+ def bufferEncoder: Encoder[BUF]
+
+ /**
+ * Specifies the [[Encoder]] for the final output value type.
+ * @since 2.0.0
+ */
+ def outputEncoder: Encoder[OUT]
+
+ /**
+ * Returns this `Aggregator` as a [[TypedColumn]] that can be used in [[Dataset]]
* operations.
* @since 1.6.0
*/
- def toColumn(
- implicit bEncoder: Encoder[B],
- cEncoder: Encoder[O]): TypedColumn[I, O] = {
+ def toColumn: TypedColumn[IN, OUT] = {
+ implicit val bEncoder = bufferEncoder
+ implicit val cEncoder = outputEncoder
+
val expr =
- new AggregateExpression(
+ AggregateExpression(
TypedAggregateExpression(this),
Complete,
- false)
+ isDistinct = false)
- new TypedColumn[I, O](expr, encoderFor[O])
+ new TypedColumn[IN, OUT](expr, encoderFor[OUT])
}
}
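
A hedged sketch of the new encoder-aware Aggregator contract; `Encoders.scalaLong` and the Dataset usage in the trailing comment are assumptions about the public API, not part of this diff:

{{{
import org.apache.spark.sql.{Encoder, Encoders}
import org.apache.spark.sql.expressions.Aggregator

// Sums Long values; encoders are now supplied by the aggregator itself rather than
// implicitly at the toColumn call site.
object LongSum extends Aggregator[Long, Long, Long] {
  def zero: Long = 0L
  def reduce(b: Long, a: Long): Long = b + a
  def merge(b1: Long, b2: Long): Long = b1 + b2
  def finish(reduction: Long): Long = reduction
  def bufferEncoder: Encoder[Long] = Encoders.scalaLong
  def outputEncoder: Encoder[Long] = Encoders.scalaLong
}

// Usage (assuming ds: Dataset[Long]): ds.select(LongSum.toColumn)
}}}
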
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala
index e9b60841fc..350c283646 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala
@@ -42,7 +42,7 @@ object Window {
* Creates a [[WindowSpec]] with the partitioning defined.
* @since 1.4.0
*/
- @scala.annotation.varargs
+ @_root_.scala.annotation.varargs
def partitionBy(colName: String, colNames: String*): WindowSpec = {
spec.partitionBy(colName, colNames : _*)
}
@@ -51,7 +51,7 @@ object Window {
* Creates a [[WindowSpec]] with the partitioning defined.
* @since 1.4.0
*/
- @scala.annotation.varargs
+ @_root_.scala.annotation.varargs
def partitionBy(cols: Column*): WindowSpec = {
spec.partitionBy(cols : _*)
}
@@ -60,7 +60,7 @@ object Window {
* Creates a [[WindowSpec]] with the ordering defined.
* @since 1.4.0
*/
- @scala.annotation.varargs
+ @_root_.scala.annotation.varargs
def orderBy(colName: String, colNames: String*): WindowSpec = {
spec.orderBy(colName, colNames : _*)
}
@@ -69,7 +69,7 @@ object Window {
* Creates a [[WindowSpec]] with the ordering defined.
* @since 1.4.0
*/
- @scala.annotation.varargs
+ @_root_.scala.annotation.varargs
def orderBy(cols: Column*): WindowSpec = {
spec.orderBy(cols : _*)
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala
index 9e9c58cb66..d716da2668 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala
@@ -39,7 +39,7 @@ class WindowSpec private[sql](
* Defines the partitioning columns in a [[WindowSpec]].
* @since 1.4.0
*/
- @scala.annotation.varargs
+ @_root_.scala.annotation.varargs
def partitionBy(colName: String, colNames: String*): WindowSpec = {
partitionBy((colName +: colNames).map(Column(_)): _*)
}
@@ -48,7 +48,7 @@ class WindowSpec private[sql](
* Defines the partitioning columns in a [[WindowSpec]].
* @since 1.4.0
*/
- @scala.annotation.varargs
+ @_root_.scala.annotation.varargs
def partitionBy(cols: Column*): WindowSpec = {
new WindowSpec(cols.map(_.expr), orderSpec, frame)
}
@@ -57,7 +57,7 @@ class WindowSpec private[sql](
* Defines the ordering columns in a [[WindowSpec]].
* @since 1.4.0
*/
- @scala.annotation.varargs
+ @_root_.scala.annotation.varargs
def orderBy(colName: String, colNames: String*): WindowSpec = {
orderBy((colName +: colNames).map(Column(_)): _*)
}
@@ -66,7 +66,7 @@ class WindowSpec private[sql](
* Defines the ordering columns in a [[WindowSpec]].
* @since 1.4.0
*/
- @scala.annotation.varargs
+ @_root_.scala.annotation.varargs
def orderBy(cols: Column*): WindowSpec = {
val sortOrder: Seq[SortOrder] = cols.map { col =>
col.expr match {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/scala/typed.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/scala/typed.scala
new file mode 100644
index 0000000000..d0eb190afd
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/scala/typed.scala
@@ -0,0 +1,89 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.expressions.scala
+
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.sql._
+import org.apache.spark.sql.execution.aggregate._
+
+/**
+ * :: Experimental ::
+ * Type-safe functions available for [[Dataset]] operations in Scala.
+ *
+ * Java users should use [[org.apache.spark.sql.expressions.java.typed]].
+ *
+ * @since 2.0.0
+ */
+@Experimental
+// scalastyle:off
+object typed {
+ // scalastyle:on
+
+ // Note: whenever we update this file, we should update the corresponding Java version too.
+ // The reason we have separate files for Java and Scala is because in the Scala version, we can
+ // use tighter types (primitive types) for return types, whereas in the Java version we can only
+ // use boxed primitive types.
+ // For example, avg in the Scala version returns Scala primitive Double, whose bytecode
+ // signature is just a java.lang.Object; avg in the Java version returns java.lang.Double.
+
+ // TODO: This is pretty hacky. Maybe we should have an object for implicit encoders.
+ private val implicits = new SQLImplicits {
+ override protected def _sqlContext: SQLContext = null
+ }
+
+ import implicits._
+
+ /**
+ * Average aggregate function.
+ *
+ * @since 2.0.0
+ */
+ def avg[IN](f: IN => Double): TypedColumn[IN, Double] = new TypedAverage(f).toColumn
+
+ /**
+ * Count aggregate function.
+ *
+ * @since 2.0.0
+ */
+ def count[IN](f: IN => Any): TypedColumn[IN, Long] = new TypedCount(f).toColumn
+
+ /**
+ * Sum aggregate function for floating point (double) type.
+ *
+ * @since 2.0.0
+ */
+ def sum[IN](f: IN => Double): TypedColumn[IN, Double] = new TypedSumDouble[IN](f).toColumn
+
+ /**
+ * Sum aggregate function for integral (long, i.e. 64 bit integer) type.
+ *
+ * @since 2.0.0
+ */
+ def sumLong[IN](f: IN => Long): TypedColumn[IN, Long] = new TypedSumLong[IN](f).toColumn
+
+ // TODO:
+ // stddevOf: Double
+ // varianceOf: Double
+ // approxCountDistinct: Long
+
+ // minOf: T
+ // maxOf: T
+
+ // firstOf: T
+ // lastOf: T
+}
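
A hedged usage sketch for the new `typed` object; the `Sale` case class and the `groupByKey`/`agg` Dataset calls are assumptions about the surrounding Dataset API, not part of this file:

{{{
import org.apache.spark.sql.{Dataset, SQLContext}
import org.apache.spark.sql.expressions.scala.typed

case class Sale(store: String, amount: Double)

def perStoreStats(sqlContext: SQLContext, sales: Dataset[Sale]) = {
  import sqlContext.implicits._
  // One pass computes the count, average and sum of amount per store key.
  sales.groupByKey(_.store)
    .agg(typed.count(_.amount), typed.avg(_.amount), typed.sum(_.amount))
}
}}}
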
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala
index 8b355befc3..48925910ac 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala
@@ -106,7 +106,7 @@ abstract class UserDefinedAggregateFunction extends Serializable {
/**
* Creates a [[Column]] for this UDAF using given [[Column]]s as input arguments.
*/
- @scala.annotation.varargs
+ @_root_.scala.annotation.varargs
def apply(exprs: Column*): Column = {
val aggregateExpression =
AggregateExpression(
@@ -120,7 +120,7 @@ abstract class UserDefinedAggregateFunction extends Serializable {
* Creates a [[Column]] for this UDAF using the distinct values of the given
* [[Column]]s as input arguments.
*/
- @scala.annotation.varargs
+ @_root_.scala.annotation.varargs
def distinct(exprs: Column*): Column = {
val aggregateExpression =
AggregateExpression(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 8abb9d7e4a..223122300d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -27,8 +27,8 @@ import org.apache.spark.sql.catalyst.analysis.{Star, UnresolvedFunction}
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
-import org.apache.spark.sql.catalyst.parser.CatalystQl
import org.apache.spark.sql.catalyst.plans.logical.BroadcastHint
+import org.apache.spark.sql.execution.SparkSqlParser
import org.apache.spark.sql.expressions.UserDefinedFunction
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils
@@ -154,6 +154,8 @@ object functions {
/**
* Aggregate function: returns the approximate number of distinct items in a group.
*
+ * @param rsd maximum estimation error allowed (default = 0.05)
+ *
* @group agg_funcs
* @since 1.3.0
*/
@@ -164,6 +166,8 @@ object functions {
/**
* Aggregate function: returns the approximate number of distinct items in a group.
*
+ * @param rsd maximum estimation error allowed (default = 0.05)
+ *
* @group agg_funcs
* @since 1.3.0
*/
@@ -332,95 +336,94 @@ object functions {
}
/**
- * Aggregate function: returns the first value in a group.
- *
- * The function by default returns the first values it sees. It will return the first non-null
- * value it sees when ignoreNulls is set to true. If all values are null, then null is returned.
- *
- * @group agg_funcs
- * @since 2.0.0
- */
+ * Aggregate function: returns the first value in a group.
+ *
+ * The function by default returns the first values it sees. It will return the first non-null
+ * value it sees when ignoreNulls is set to true. If all values are null, then null is returned.
+ *
+ * @group agg_funcs
+ * @since 2.0.0
+ */
def first(e: Column, ignoreNulls: Boolean): Column = withAggregateFunction {
new First(e.expr, Literal(ignoreNulls))
}
/**
- * Aggregate function: returns the first value of a column in a group.
- *
- * The function by default returns the first values it sees. It will return the first non-null
- * value it sees when ignoreNulls is set to true. If all values are null, then null is returned.
- *
- * @group agg_funcs
- * @since 2.0.0
- */
+ * Aggregate function: returns the first value of a column in a group.
+ *
+ * The function by default returns the first values it sees. It will return the first non-null
+ * value it sees when ignoreNulls is set to true. If all values are null, then null is returned.
+ *
+ * @group agg_funcs
+ * @since 2.0.0
+ */
def first(columnName: String, ignoreNulls: Boolean): Column = {
first(Column(columnName), ignoreNulls)
}
/**
- * Aggregate function: returns the first value in a group.
- *
- * The function by default returns the first values it sees. It will return the first non-null
- * value it sees when ignoreNulls is set to true. If all values are null, then null is returned.
- *
- * @group agg_funcs
- * @since 1.3.0
- */
+ * Aggregate function: returns the first value in a group.
+ *
+ * The function by default returns the first values it sees. It will return the first non-null
+ * value it sees when ignoreNulls is set to true. If all values are null, then null is returned.
+ *
+ * @group agg_funcs
+ * @since 1.3.0
+ */
def first(e: Column): Column = first(e, ignoreNulls = false)
/**
- * Aggregate function: returns the first value of a column in a group.
- *
- * The function by default returns the first values it sees. It will return the first non-null
- * value it sees when ignoreNulls is set to true. If all values are null, then null is returned.
- *
- * @group agg_funcs
- * @since 1.3.0
- */
+ * Aggregate function: returns the first value of a column in a group.
+ *
+ * The function by default returns the first values it sees. It will return the first non-null
+ * value it sees when ignoreNulls is set to true. If all values are null, then null is returned.
+ *
+ * @group agg_funcs
+ * @since 1.3.0
+ */
def first(columnName: String): Column = first(Column(columnName))
-
/**
- * Aggregate function: indicates whether a specified column in a GROUP BY list is aggregated
- * or not, returns 1 for aggregated or 0 for not aggregated in the result set.
- *
- * @group agg_funcs
- * @since 2.0.0
- */
+ * Aggregate function: indicates whether a specified column in a GROUP BY list is aggregated
+ * or not, returns 1 for aggregated or 0 for not aggregated in the result set.
+ *
+ * @group agg_funcs
+ * @since 2.0.0
+ */
def grouping(e: Column): Column = Column(Grouping(e.expr))
/**
- * Aggregate function: indicates whether a specified column in a GROUP BY list is aggregated
- * or not, returns 1 for aggregated or 0 for not aggregated in the result set.
- *
- * @group agg_funcs
- * @since 2.0.0
- */
+ * Aggregate function: indicates whether a specified column in a GROUP BY list is aggregated
+ * or not, returns 1 for aggregated or 0 for not aggregated in the result set.
+ *
+ * @group agg_funcs
+ * @since 2.0.0
+ */
def grouping(columnName: String): Column = grouping(Column(columnName))
/**
- * Aggregate function: returns the level of grouping, equals to
- *
- * (grouping(c1) << (n-1)) + (grouping(c2) << (n-2)) + ... + grouping(cn)
- *
- * Note: the list of columns should match with grouping columns exactly, or empty (means all the
- * grouping columns).
- *
- * @group agg_funcs
- * @since 2.0.0
- */
+ * Aggregate function: returns the level of grouping, equals to
+ *
+ * (grouping(c1) << (n-1)) + (grouping(c2) << (n-2)) + ... + grouping(cn)
+ *
+ * Note: the list of columns should exactly match the grouping columns, or be empty (which means
+ * all the grouping columns).
+ *
+ * @group agg_funcs
+ * @since 2.0.0
+ */
def grouping_id(cols: Column*): Column = Column(GroupingID(cols.map(_.expr)))
/**
- * Aggregate function: returns the level of grouping, equals to
- *
- * (grouping(c1) << (n-1)) + (grouping(c2) << (n-2)) + ... + grouping(cn)
- *
- * Note: the list of columns should match with grouping columns exactly.
- *
- * @group agg_funcs
- * @since 2.0.0
- */
+ * Aggregate function: returns the level of grouping, equals to
+ *
+ * (grouping(c1) << (n-1)) + (grouping(c2) << (n-2)) + ... + grouping(cn)
+ *
+ * Note: the list of columns should exactly match the grouping columns.
+ *
+ * @group agg_funcs
+ * @since 2.0.0
+ */
def grouping_id(colName: String, colNames: String*): Column = {
grouping_id((Seq(colName) ++ colNames).map(n => Column(n)) : _*)
}
@@ -442,51 +445,51 @@ object functions {
def kurtosis(columnName: String): Column = kurtosis(Column(columnName))
/**
- * Aggregate function: returns the last value in a group.
- *
- * The function by default returns the last values it sees. It will return the last non-null
- * value it sees when ignoreNulls is set to true. If all values are null, then null is returned.
- *
- * @group agg_funcs
- * @since 2.0.0
- */
+ * Aggregate function: returns the last value in a group.
+ *
+ * The function by default returns the last values it sees. It will return the last non-null
+ * value it sees when ignoreNulls is set to true. If all values are null, then null is returned.
+ *
+ * @group agg_funcs
+ * @since 2.0.0
+ */
def last(e: Column, ignoreNulls: Boolean): Column = withAggregateFunction {
new Last(e.expr, Literal(ignoreNulls))
}
/**
- * Aggregate function: returns the last value of the column in a group.
- *
- * The function by default returns the last values it sees. It will return the last non-null
- * value it sees when ignoreNulls is set to true. If all values are null, then null is returned.
- *
- * @group agg_funcs
- * @since 2.0.0
- */
+ * Aggregate function: returns the last value of the column in a group.
+ *
+ * The function by default returns the last values it sees. It will return the last non-null
+ * value it sees when ignoreNulls is set to true. If all values are null, then null is returned.
+ *
+ * @group agg_funcs
+ * @since 2.0.0
+ */
def last(columnName: String, ignoreNulls: Boolean): Column = {
last(Column(columnName), ignoreNulls)
}
/**
- * Aggregate function: returns the last value in a group.
- *
- * The function by default returns the last values it sees. It will return the last non-null
- * value it sees when ignoreNulls is set to true. If all values are null, then null is returned.
- *
- * @group agg_funcs
- * @since 1.3.0
- */
+ * Aggregate function: returns the last value in a group.
+ *
+ * The function by default returns the last values it sees. It will return the last non-null
+ * value it sees when ignoreNulls is set to true. If all values are null, then null is returned.
+ *
+ * @group agg_funcs
+ * @since 1.3.0
+ */
def last(e: Column): Column = last(e, ignoreNulls = false)
/**
- * Aggregate function: returns the last value of the column in a group.
- *
- * The function by default returns the last values it sees. It will return the last non-null
- * value it sees when ignoreNulls is set to true. If all values are null, then null is returned.
- *
- * @group agg_funcs
- * @since 1.3.0
- */
+ * Aggregate function: returns the last value of the column in a group.
+ *
+ * The function by default returns the last values it sees. It will return the last non-null
+ * value it sees when ignoreNulls is set to true. If all values are null, then null is returned.
+ *
+ * @group agg_funcs
+ * @since 1.3.0
+ */
def last(columnName: String): Column = last(Column(columnName), ignoreNulls = false)
/**
@@ -1172,8 +1175,7 @@ object functions {
* @group normal_funcs
*/
def expr(expr: String): Column = {
- val parser = SQLContext.getActive().map(_.sessionState.sqlParser).getOrElse(new CatalystQl())
- Column(parser.parseExpression(expr))
+ Column(SparkSqlParser.parseExpression(expr))
}
//////////////////////////////////////////////////////////////////////////////////////////////
@@ -2233,7 +2235,7 @@ object functions {
/**
* Splits str around pattern (pattern is a regular expression).
- * NOTE: pattern is a string represent the regular expression.
+ * NOTE: pattern is a string representation of the regular expression.
*
* @group string_funcs
* @since 1.5.0
@@ -2268,9 +2270,9 @@ object functions {
/**
* Translate any character in the src by a character in replaceString.
- * The characters in replaceString is corresponding to the characters in matchingString.
- * The translate will happen when any character in the string matching with the character
- * in the matchingString.
+ * The characters in replaceString correspond to the characters in matchingString.
+ * The translation occurs whenever a character in the string matches a character
+ * in `matchingString`.
*
* @group string_funcs
* @since 1.5.0
@@ -2551,12 +2553,146 @@ object functions {
ToUTCTimestamp(ts.expr, Literal(tz))
}
+ /**
+ * Bucketize rows into one or more time windows given a timestamp specifying column. Window
+ * starts are inclusive but the window ends are exclusive, e.g. 12:05 will be in the window
+ * [12:05,12:10) but not in [12:00,12:05). Windows can support microsecond precision. Windows in
+ * the order of months are not supported. The following example takes the average stock price for
+ * a one minute window every 10 seconds starting 5 seconds after the hour:
+ *
+ * {{{
+ * val df = ... // schema => timestamp: TimestampType, stockId: StringType, price: DoubleType
+ * df.groupBy(window($"time", "1 minute", "10 seconds", "5 seconds"), $"stockId")
+ * .agg(mean("price"))
+ * }}}
+ *
+ * The windows will look like:
+ *
+ * {{{
+ * 09:00:05-09:01:05
+ * 09:00:15-09:01:15
+ * 09:00:25-09:01:25 ...
+ * }}}
+ *
+ * For a continuous query, you may use the function `current_timestamp` to generate windows on
+ * processing time.
+ *
+ * @param timeColumn The column or the expression to use as the timestamp for windowing by time.
+ * The time column must be of TimestampType.
+ * @param windowDuration A string specifying the width of the window, e.g. `10 minutes`,
+ * `1 second`. Check [[org.apache.spark.unsafe.types.CalendarInterval]] for
+ * valid duration identifiers.
+ * @param slideDuration A string specifying the sliding interval of the window, e.g. `1 minute`.
+ * A new window will be generated every `slideDuration`. Must be less than
+ * or equal to the `windowDuration`. Check
+ * [[org.apache.spark.unsafe.types.CalendarInterval]] for valid duration
+ * identifiers.
+ * @param startTime The offset with respect to 1970-01-01 00:00:00 UTC with which to start
+ * window intervals. For example, in order to have hourly tumbling windows that
+ * start 15 minutes past the hour, e.g. 12:15-13:15, 13:15-14:15... provide
+ * `startTime` as `15 minutes`.
+ *
+ * @group datetime_funcs
+ * @since 2.0.0
+ */
+ @Experimental
+ def window(
+ timeColumn: Column,
+ windowDuration: String,
+ slideDuration: String,
+ startTime: String): Column = {
+ withExpr {
+ TimeWindow(timeColumn.expr, windowDuration, slideDuration, startTime)
+ }.as("window")
+ }
+
+
+ /**
+ * Bucketize rows into one or more time windows given a timestamp specifying column. Window
+ * starts are inclusive but the window ends are exclusive, e.g. 12:05 will be in the window
+ * [12:05,12:10) but not in [12:00,12:05). Windows can support microsecond precision. Windows in
+ * the order of months are not supported. The windows start at 1970-01-01 00:00:00 UTC.
+ * The following example takes the average stock price for a one minute window every 10 seconds:
+ *
+ * {{{
+ * val df = ... // schema => timestamp: TimestampType, stockId: StringType, price: DoubleType
+ * df.groupBy(window($"time", "1 minute", "10 seconds"), $"stockId")
+ * .agg(mean("price"))
+ * }}}
+ *
+ * The windows will look like:
+ *
+ * {{{
+ * 09:00:00-09:01:00
+ * 09:00:10-09:01:10
+ * 09:00:20-09:01:20 ...
+ * }}}
+ *
+ * For a continuous query, you may use the function `current_timestamp` to generate windows on
+ * processing time.
+ *
+ * @param timeColumn The column or the expression to use as the timestamp for windowing by time.
+ * The time column must be of TimestampType.
+ * @param windowDuration A string specifying the width of the window, e.g. `10 minutes`,
+ * `1 second`. Check [[org.apache.spark.unsafe.types.CalendarInterval]] for
+ * valid duration identifiers.
+ * @param slideDuration A string specifying the sliding interval of the window, e.g. `1 minute`.
+ * A new window will be generated every `slideDuration`. Must be less than
+ * or equal to the `windowDuration`. Check
+ * [[org.apache.spark.unsafe.types.CalendarInterval]] for valid duration.
+ *
+ * @group datetime_funcs
+ * @since 2.0.0
+ */
+ @Experimental
+ def window(timeColumn: Column, windowDuration: String, slideDuration: String): Column = {
+ window(timeColumn, windowDuration, slideDuration, "0 second")
+ }
+
+ /**
+ * Generates tumbling time windows given a timestamp specifying column. Window
+ * starts are inclusive but the window ends are exclusive, e.g. 12:05 will be in the window
+ * [12:05,12:10) but not in [12:00,12:05). Windows can support microsecond precision. Windows in
+ * the order of months are not supported. The windows start at 1970-01-01 00:00:00 UTC.
+ * The following example takes the average stock price for a one minute tumbling window:
+ *
+ * {{{
+ * val df = ... // schema => timestamp: TimestampType, stockId: StringType, price: DoubleType
+ * df.groupBy(window($"time", "1 minute"), $"stockId")
+ * .agg(mean("price"))
+ * }}}
+ *
+ * The windows will look like:
+ *
+ * {{{
+ * 09:00:00-09:01:00
+ * 09:01:00-09:02:00
+ * 09:02:00-09:03:00 ...
+ * }}}
+ *
+ * For a continuous query, you may use the function `current_timestamp` to generate windows on
+ * processing time.
+ *
+ * @param timeColumn The column or the expression to use as the timestamp for windowing by time.
+ * The time column must be of TimestampType.
+ * @param windowDuration A string specifying the width of the window, e.g. `10 minutes`,
+ * `1 second`. Check [[org.apache.spark.unsafe.types.CalendarInterval]] for
+ * valid duration identifiers.
+ *
+ * @group datetime_funcs
+ * @since 2.0.0
+ */
+ @Experimental
+ def window(timeColumn: Column, windowDuration: String): Column = {
+ window(timeColumn, windowDuration, windowDuration, "0 second")
+ }
+
//////////////////////////////////////////////////////////////////////////////////////////////
// Collection functions
//////////////////////////////////////////////////////////////////////////////////////////////
/**
- * Returns true if the array contain the value
+ * Returns true if the array contains `value`.
* @group collection_funcs
* @since 1.5.0
*/
@@ -2784,7 +2920,7 @@ object functions {
/**
* Defines a user-defined function (UDF) using a Scala closure. For this variant, the caller must
- * specifcy the output data type, and there is no automatic input type coercion.
+ * specify the output data type, and there is no automatic input type coercion.
*
* @param f A closure in Scala
* @param dataType The output data type of the UDF
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 77af0e000b..2f9d63c2e8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -25,9 +25,9 @@ import scala.collection.immutable
import org.apache.parquet.hadoop.ParquetOutputCommitter
import org.apache.spark.internal.Logging
+import org.apache.spark.internal.config._
+import org.apache.spark.network.util.ByteUnit
import org.apache.spark.sql.catalyst.CatalystConf
-import org.apache.spark.sql.catalyst.parser.ParserConf
-import org.apache.spark.util.Utils
////////////////////////////////////////////////////////////////////////////////////////////////////
// This file defines the configuration options for Spark SQL.
@@ -37,418 +37,299 @@ import org.apache.spark.util.Utils
object SQLConf {
private val sqlConfEntries = java.util.Collections.synchronizedMap(
- new java.util.HashMap[String, SQLConfEntry[_]]())
+ new java.util.HashMap[String, ConfigEntry[_]]())
- /**
- * An entry contains all meta information for a configuration.
- *
- * @param key the key for the configuration
- * @param defaultValue the default value for the configuration
- * @param valueConverter how to convert a string to the value. It should throw an exception if the
- * string does not have the required format.
- * @param stringConverter how to convert a value to a string that the user can use it as a valid
- * string value. It's usually `toString`. But sometimes, a custom converter
- * is necessary. E.g., if T is List[String], `a, b, c` is better than
- * `List(a, b, c)`.
- * @param doc the document for the configuration
- * @param isPublic if this configuration is public to the user. If it's `false`, this
- * configuration is only used internally and we should not expose it to the user.
- * @tparam T the value type
- */
- class SQLConfEntry[T] private(
- val key: String,
- val defaultValue: Option[T],
- val valueConverter: String => T,
- val stringConverter: T => String,
- val doc: String,
- val isPublic: Boolean) {
-
- def defaultValueString: String = defaultValue.map(stringConverter).getOrElse("<undefined>")
-
- override def toString: String = {
- s"SQLConfEntry(key = $key, defaultValue=$defaultValueString, doc=$doc, isPublic = $isPublic)"
- }
+ private def register(entry: ConfigEntry[_]): Unit = sqlConfEntries.synchronized {
+ require(!sqlConfEntries.containsKey(entry.key),
+ s"Duplicate SQLConfigEntry. ${entry.key} has been registered")
+ sqlConfEntries.put(entry.key, entry)
}
- object SQLConfEntry {
-
- private def apply[T](
- key: String,
- defaultValue: Option[T],
- valueConverter: String => T,
- stringConverter: T => String,
- doc: String,
- isPublic: Boolean): SQLConfEntry[T] =
- sqlConfEntries.synchronized {
- if (sqlConfEntries.containsKey(key)) {
- throw new IllegalArgumentException(s"Duplicate SQLConfEntry. $key has been registered")
- }
- val entry =
- new SQLConfEntry[T](key, defaultValue, valueConverter, stringConverter, doc, isPublic)
- sqlConfEntries.put(key, entry)
- entry
- }
-
- def intConf(
- key: String,
- defaultValue: Option[Int] = None,
- doc: String = "",
- isPublic: Boolean = true): SQLConfEntry[Int] =
- SQLConfEntry(key, defaultValue, { v =>
- try {
- v.toInt
- } catch {
- case _: NumberFormatException =>
- throw new IllegalArgumentException(s"$key should be int, but was $v")
- }
- }, _.toString, doc, isPublic)
-
- def longConf(
- key: String,
- defaultValue: Option[Long] = None,
- doc: String = "",
- isPublic: Boolean = true): SQLConfEntry[Long] =
- SQLConfEntry(key, defaultValue, { v =>
- try {
- v.toLong
- } catch {
- case _: NumberFormatException =>
- throw new IllegalArgumentException(s"$key should be long, but was $v")
- }
- }, _.toString, doc, isPublic)
-
- def longMemConf(
- key: String,
- defaultValue: Option[Long] = None,
- doc: String = "",
- isPublic: Boolean = true): SQLConfEntry[Long] =
- SQLConfEntry(key, defaultValue, { v =>
- try {
- v.toLong
- } catch {
- case _: NumberFormatException =>
- try {
- Utils.byteStringAsBytes(v)
- } catch {
- case _: NumberFormatException =>
- throw new IllegalArgumentException(s"$key should be long, but was $v")
- }
- }
- }, _.toString, doc, isPublic)
-
- def doubleConf(
- key: String,
- defaultValue: Option[Double] = None,
- doc: String = "",
- isPublic: Boolean = true): SQLConfEntry[Double] =
- SQLConfEntry(key, defaultValue, { v =>
- try {
- v.toDouble
- } catch {
- case _: NumberFormatException =>
- throw new IllegalArgumentException(s"$key should be double, but was $v")
- }
- }, _.toString, doc, isPublic)
-
- def booleanConf(
- key: String,
- defaultValue: Option[Boolean] = None,
- doc: String = "",
- isPublic: Boolean = true): SQLConfEntry[Boolean] =
- SQLConfEntry(key, defaultValue, { v =>
- try {
- v.toBoolean
- } catch {
- case _: IllegalArgumentException =>
- throw new IllegalArgumentException(s"$key should be boolean, but was $v")
- }
- }, _.toString, doc, isPublic)
-
- def stringConf(
- key: String,
- defaultValue: Option[String] = None,
- doc: String = "",
- isPublic: Boolean = true): SQLConfEntry[String] =
- SQLConfEntry(key, defaultValue, v => v, v => v, doc, isPublic)
-
- def enumConf[T](
- key: String,
- valueConverter: String => T,
- validValues: Set[T],
- defaultValue: Option[T] = None,
- doc: String = "",
- isPublic: Boolean = true): SQLConfEntry[T] =
- SQLConfEntry(key, defaultValue, v => {
- val _v = valueConverter(v)
- if (!validValues.contains(_v)) {
- throw new IllegalArgumentException(
- s"The value of $key should be one of ${validValues.mkString(", ")}, but was $v")
- }
- _v
- }, _.toString, doc, isPublic)
-
- def seqConf[T](
- key: String,
- valueConverter: String => T,
- defaultValue: Option[Seq[T]] = None,
- doc: String = "",
- isPublic: Boolean = true): SQLConfEntry[Seq[T]] = {
- SQLConfEntry(
- key, defaultValue, _.split(",").map(valueConverter), _.mkString(","), doc, isPublic)
- }
+ private[sql] object SQLConfigBuilder {
- def stringSeqConf(
- key: String,
- defaultValue: Option[Seq[String]] = None,
- doc: String = "",
- isPublic: Boolean = true): SQLConfEntry[Seq[String]] = {
- seqConf(key, s => s, defaultValue, doc, isPublic)
- }
- }
+ def apply(key: String): ConfigBuilder = new ConfigBuilder(key).onCreate(register)
- import SQLConfEntry._
+ }
- val ALLOW_MULTIPLE_CONTEXTS = booleanConf("spark.sql.allowMultipleContexts",
- defaultValue = Some(true),
- doc = "When set to true, creating multiple SQLContexts/HiveContexts is allowed. " +
+ val ALLOW_MULTIPLE_CONTEXTS = SQLConfigBuilder("spark.sql.allowMultipleContexts")
+ .doc("When set to true, creating multiple SQLContexts/HiveContexts is allowed. " +
"When set to false, only one SQLContext/HiveContext is allowed to be created " +
"through the constructor (new SQLContexts/HiveContexts created through newSession " +
"method is allowed). Please note that this conf needs to be set in Spark Conf. Once " +
"a SQLContext/HiveContext has been created, changing the value of this conf will not " +
- "have effect.",
- isPublic = true)
-
- val COMPRESS_CACHED = booleanConf("spark.sql.inMemoryColumnarStorage.compressed",
- defaultValue = Some(true),
- doc = "When set to true Spark SQL will automatically select a compression codec for each " +
- "column based on statistics of the data.",
- isPublic = false)
-
- val COLUMN_BATCH_SIZE = intConf("spark.sql.inMemoryColumnarStorage.batchSize",
- defaultValue = Some(10000),
- doc = "Controls the size of batches for columnar caching. Larger batch sizes can improve " +
- "memory utilization and compression, but risk OOMs when caching data.",
- isPublic = false)
+ "have effect.")
+ .booleanConf
+ .createWithDefault(true)
+
+ val COMPRESS_CACHED = SQLConfigBuilder("spark.sql.inMemoryColumnarStorage.compressed")
+ .internal()
+ .doc("When set to true Spark SQL will automatically select a compression codec for each " +
+ "column based on statistics of the data.")
+ .booleanConf
+ .createWithDefault(true)
+
+ val COLUMN_BATCH_SIZE = SQLConfigBuilder("spark.sql.inMemoryColumnarStorage.batchSize")
+ .internal()
+ .doc("Controls the size of batches for columnar caching. Larger batch sizes can improve " +
+ "memory utilization and compression, but risk OOMs when caching data.")
+ .intConf
+ .createWithDefault(10000)
val IN_MEMORY_PARTITION_PRUNING =
- booleanConf("spark.sql.inMemoryColumnarStorage.partitionPruning",
- defaultValue = Some(true),
- doc = "When true, enable partition pruning for in-memory columnar tables.",
- isPublic = false)
-
- val PREFER_SORTMERGEJOIN = booleanConf("spark.sql.join.preferSortMergeJoin",
- defaultValue = Some(true),
- doc = "When true, prefer sort merge join over shuffle hash join.",
- isPublic = false)
-
- val AUTO_BROADCASTJOIN_THRESHOLD = intConf("spark.sql.autoBroadcastJoinThreshold",
- defaultValue = Some(10 * 1024 * 1024),
- doc = "Configures the maximum size in bytes for a table that will be broadcast to all worker " +
+ SQLConfigBuilder("spark.sql.inMemoryColumnarStorage.partitionPruning")
+ .internal()
+ .doc("When true, enable partition pruning for in-memory columnar tables.")
+ .booleanConf
+ .createWithDefault(true)
+
+ val PREFER_SORTMERGEJOIN = SQLConfigBuilder("spark.sql.join.preferSortMergeJoin")
+ .internal()
+ .doc("When true, prefer sort merge join over shuffle hash join.")
+ .booleanConf
+ .createWithDefault(true)
+
+ val AUTO_BROADCASTJOIN_THRESHOLD = SQLConfigBuilder("spark.sql.autoBroadcastJoinThreshold")
+ .doc("Configures the maximum size in bytes for a table that will be broadcast to all worker " +
"nodes when performing a join. By setting this value to -1 broadcasting can be disabled. " +
"Note that currently statistics are only supported for Hive Metastore tables where the " +
"command<code>ANALYZE TABLE &lt;tableName&gt; COMPUTE STATISTICS noscan</code> has been run.")
+ .intConf
+ .createWithDefault(10 * 1024 * 1024)
- val DEFAULT_SIZE_IN_BYTES = longConf(
- "spark.sql.defaultSizeInBytes",
- doc = "The default table size used in query planning. By default, it is set to a larger " +
+ val DEFAULT_SIZE_IN_BYTES = SQLConfigBuilder("spark.sql.defaultSizeInBytes")
+ .internal()
+ .doc("The default table size used in query planning. By default, it is set to a larger " +
"value than `spark.sql.autoBroadcastJoinThreshold` to be more conservative. That is to say " +
"by default the optimizer will not choose to broadcast a table unless it knows for sure " +
- "its size is small enough.",
- isPublic = false)
+ "its size is small enough.")
+ .longConf
+ .createWithDefault(-1)
- val SHUFFLE_PARTITIONS = intConf("spark.sql.shuffle.partitions",
- defaultValue = Some(200),
- doc = "The default number of partitions to use when shuffling data for joins or aggregations.")
+ val SHUFFLE_PARTITIONS = SQLConfigBuilder("spark.sql.shuffle.partitions")
+ .doc("The default number of partitions to use when shuffling data for joins or aggregations.")
+ .intConf
+ .createWithDefault(200)
val SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE =
- longMemConf("spark.sql.adaptive.shuffle.targetPostShuffleInputSize",
- defaultValue = Some(64 * 1024 * 1024),
- doc = "The target post-shuffle input size in bytes of a task.")
+ SQLConfigBuilder("spark.sql.adaptive.shuffle.targetPostShuffleInputSize")
+ .doc("The target post-shuffle input size in bytes of a task.")
+ .bytesConf(ByteUnit.BYTE)
+ .createWithDefault(64 * 1024 * 1024)
- val ADAPTIVE_EXECUTION_ENABLED = booleanConf("spark.sql.adaptive.enabled",
- defaultValue = Some(false),
- doc = "When true, enable adaptive query execution.")
+ val ADAPTIVE_EXECUTION_ENABLED = SQLConfigBuilder("spark.sql.adaptive.enabled")
+ .doc("When true, enable adaptive query execution.")
+ .booleanConf
+ .createWithDefault(false)
val SHUFFLE_MIN_NUM_POSTSHUFFLE_PARTITIONS =
- intConf("spark.sql.adaptive.minNumPostShufflePartitions",
- defaultValue = Some(-1),
- doc = "The advisory minimal number of post-shuffle partitions provided to " +
+ SQLConfigBuilder("spark.sql.adaptive.minNumPostShufflePartitions")
+ .internal()
+ .doc("The advisory minimal number of post-shuffle partitions provided to " +
"ExchangeCoordinator. This setting is used in our test to make sure we " +
"have enough parallelism to expose issues that will not be exposed with a " +
"single partition. When the value is a non-positive value, this setting will " +
- "not be provided to ExchangeCoordinator.",
- isPublic = false)
-
- val SUBEXPRESSION_ELIMINATION_ENABLED = booleanConf("spark.sql.subexpressionElimination.enabled",
- defaultValue = Some(true),
- doc = "When true, common subexpressions will be eliminated.",
- isPublic = false)
-
- val CASE_SENSITIVE = booleanConf("spark.sql.caseSensitive",
- defaultValue = Some(true),
- doc = "Whether the query analyzer should be case sensitive or not.")
-
- val PARQUET_FILE_SCAN = booleanConf("spark.sql.parquet.fileScan",
- defaultValue = Some(true),
- doc = "Use the new FileScanRDD path for reading parquet data.",
- isPublic = false)
-
- val PARQUET_SCHEMA_MERGING_ENABLED = booleanConf("spark.sql.parquet.mergeSchema",
- defaultValue = Some(false),
- doc = "When true, the Parquet data source merges schemas collected from all data files, " +
- "otherwise the schema is picked from the summary file or a random data file " +
- "if no summary file is available.")
-
- val PARQUET_SCHEMA_RESPECT_SUMMARIES = booleanConf("spark.sql.parquet.respectSummaryFiles",
- defaultValue = Some(false),
- doc = "When true, we make assumption that all part-files of Parquet are consistent with " +
- "summary files and we will ignore them when merging schema. Otherwise, if this is " +
- "false, which is the default, we will merge all part-files. This should be considered " +
- "as expert-only option, and shouldn't be enabled before knowing what it means exactly.")
-
- val PARQUET_BINARY_AS_STRING = booleanConf("spark.sql.parquet.binaryAsString",
- defaultValue = Some(false),
- doc = "Some other Parquet-producing systems, in particular Impala and older versions of " +
+ "not be provided to ExchangeCoordinator.")
+ .intConf
+ .createWithDefault(-1)
+
+ val SUBEXPRESSION_ELIMINATION_ENABLED =
+ SQLConfigBuilder("spark.sql.subexpressionElimination.enabled")
+ .internal()
+ .doc("When true, common subexpressions will be eliminated.")
+ .booleanConf
+ .createWithDefault(true)
+
+ val CASE_SENSITIVE = SQLConfigBuilder("spark.sql.caseSensitive")
+ .doc("Whether the query analyzer should be case sensitive or not.")
+ .booleanConf
+ .createWithDefault(true)
+
+ val PARQUET_SCHEMA_MERGING_ENABLED = SQLConfigBuilder("spark.sql.parquet.mergeSchema")
+ .doc("When true, the Parquet data source merges schemas collected from all data files, " +
+ "otherwise the schema is picked from the summary file or a random data file " +
+ "if no summary file is available.")
+ .booleanConf
+ .createWithDefault(false)
+
+ val PARQUET_SCHEMA_RESPECT_SUMMARIES = SQLConfigBuilder("spark.sql.parquet.respectSummaryFiles")
+ .doc("When true, we make assumption that all part-files of Parquet are consistent with " +
+ "summary files and we will ignore them when merging schema. Otherwise, if this is " +
+ "false, which is the default, we will merge all part-files. This should be considered " +
+ "as expert-only option, and shouldn't be enabled before knowing what it means exactly.")
+ .booleanConf
+ .createWithDefault(false)
+
+ val PARQUET_BINARY_AS_STRING = SQLConfigBuilder("spark.sql.parquet.binaryAsString")
+ .doc("Some other Parquet-producing systems, in particular Impala and older versions of " +
"Spark SQL, do not differentiate between binary data and strings when writing out the " +
"Parquet schema. This flag tells Spark SQL to interpret binary data as a string to provide " +
"compatibility with these systems.")
+ .booleanConf
+ .createWithDefault(false)
- val PARQUET_INT96_AS_TIMESTAMP = booleanConf("spark.sql.parquet.int96AsTimestamp",
- defaultValue = Some(true),
- doc = "Some Parquet-producing systems, in particular Impala, store Timestamp into INT96. " +
+ val PARQUET_INT96_AS_TIMESTAMP = SQLConfigBuilder("spark.sql.parquet.int96AsTimestamp")
+ .doc("Some Parquet-producing systems, in particular Impala, store Timestamp into INT96. " +
"Spark would also store Timestamp as INT96 because we need to avoid precision lost of the " +
"nanoseconds field. This flag tells Spark SQL to interpret INT96 data as a timestamp to " +
"provide compatibility with these systems.")
+ .booleanConf
+ .createWithDefault(true)
- val PARQUET_CACHE_METADATA = booleanConf("spark.sql.parquet.cacheMetadata",
- defaultValue = Some(true),
- doc = "Turns on caching of Parquet schema metadata. Can speed up querying of static data.")
+ val PARQUET_CACHE_METADATA = SQLConfigBuilder("spark.sql.parquet.cacheMetadata")
+ .doc("Turns on caching of Parquet schema metadata. Can speed up querying of static data.")
+ .booleanConf
+ .createWithDefault(true)
- val PARQUET_COMPRESSION = enumConf("spark.sql.parquet.compression.codec",
- valueConverter = v => v.toLowerCase,
- validValues = Set("uncompressed", "snappy", "gzip", "lzo"),
- defaultValue = Some("gzip"),
- doc = "Sets the compression codec use when writing Parquet files. Acceptable values include: " +
+ val PARQUET_COMPRESSION = SQLConfigBuilder("spark.sql.parquet.compression.codec")
+ .doc("Sets the compression codec use when writing Parquet files. Acceptable values include: " +
"uncompressed, snappy, gzip, lzo.")
-
- val PARQUET_FILTER_PUSHDOWN_ENABLED = booleanConf("spark.sql.parquet.filterPushdown",
- defaultValue = Some(true),
- doc = "Enables Parquet filter push-down optimization when set to true.")
-
- val PARQUET_WRITE_LEGACY_FORMAT = booleanConf(
- key = "spark.sql.parquet.writeLegacyFormat",
- defaultValue = Some(false),
- doc = "Whether to follow Parquet's format specification when converting Parquet schema to " +
+ .stringConf
+ .transform(_.toLowerCase())
+ .checkValues(Set("uncompressed", "snappy", "gzip", "lzo"))
+ .createWithDefault("snappy")
+
+ val PARQUET_FILTER_PUSHDOWN_ENABLED = SQLConfigBuilder("spark.sql.parquet.filterPushdown")
+ .doc("Enables Parquet filter push-down optimization when set to true.")
+ .booleanConf
+ .createWithDefault(true)
+
+ val PARQUET_WRITE_LEGACY_FORMAT = SQLConfigBuilder("spark.sql.parquet.writeLegacyFormat")
+ .doc("Whether to follow Parquet's format specification when converting Parquet schema to " +
"Spark SQL schema and vice versa.")
+ .booleanConf
+ .createWithDefault(false)
- val PARQUET_OUTPUT_COMMITTER_CLASS = stringConf(
- key = "spark.sql.parquet.output.committer.class",
- defaultValue = Some(classOf[ParquetOutputCommitter].getName),
- doc = "The output committer class used by Parquet. The specified class needs to be a " +
+ val PARQUET_OUTPUT_COMMITTER_CLASS = SQLConfigBuilder("spark.sql.parquet.output.committer.class")
+ .doc("The output committer class used by Parquet. The specified class needs to be a " +
"subclass of org.apache.hadoop.mapreduce.OutputCommitter. Typically, it's also a subclass " +
"of org.apache.parquet.hadoop.ParquetOutputCommitter. NOTE: 1. Instead of SQLConf, this " +
"option must be set in Hadoop Configuration. 2. This option overrides " +
"\"spark.sql.sources.outputCommitterClass\".")
-
- val PARQUET_VECTORIZED_READER_ENABLED = booleanConf(
- key = "spark.sql.parquet.enableVectorizedReader",
- defaultValue = Some(true),
- doc = "Enables vectorized parquet decoding.")
-
- val ORC_FILTER_PUSHDOWN_ENABLED = booleanConf("spark.sql.orc.filterPushdown",
- defaultValue = Some(false),
- doc = "When true, enable filter pushdown for ORC files.")
-
- val HIVE_VERIFY_PARTITION_PATH = booleanConf("spark.sql.hive.verifyPartitionPath",
- defaultValue = Some(false),
- doc = "When true, check all the partition paths under the table\'s root directory " +
- "when reading data stored in HDFS.")
-
- val HIVE_METASTORE_PARTITION_PRUNING = booleanConf("spark.sql.hive.metastorePartitionPruning",
- defaultValue = Some(false),
- doc = "When true, some predicates will be pushed down into the Hive metastore so that " +
- "unmatching partitions can be eliminated earlier.")
-
- val NATIVE_VIEW = booleanConf("spark.sql.nativeView",
- defaultValue = Some(true),
- doc = "When true, CREATE VIEW will be handled by Spark SQL instead of Hive native commands. " +
- "Note that this function is experimental and should ony be used when you are using " +
- "non-hive-compatible tables written by Spark SQL. The SQL string used to create " +
- "view should be fully qualified, i.e. use `tbl1`.`col1` instead of `*` whenever " +
- "possible, or you may get wrong result.",
- isPublic = false)
-
- val CANONICAL_NATIVE_VIEW = booleanConf("spark.sql.nativeView.canonical",
- defaultValue = Some(true),
- doc = "When this option and spark.sql.nativeView are both true, Spark SQL tries to handle " +
- "CREATE VIEW statement using SQL query string generated from view definition logical " +
- "plan. If the logical plan doesn't have a SQL representation, we fallback to the " +
- "original native view implementation.",
- isPublic = false)
-
- val COLUMN_NAME_OF_CORRUPT_RECORD = stringConf("spark.sql.columnNameOfCorruptRecord",
- defaultValue = Some("_corrupt_record"),
- doc = "The name of internal column for storing raw/un-parsed JSON records that fail to parse.")
-
- val BROADCAST_TIMEOUT = intConf("spark.sql.broadcastTimeout",
- defaultValue = Some(5 * 60),
- doc = "Timeout in seconds for the broadcast wait time in broadcast joins.")
+ .stringConf
+ .createWithDefault(classOf[ParquetOutputCommitter].getName)
+
+ val PARQUET_VECTORIZED_READER_ENABLED =
+ SQLConfigBuilder("spark.sql.parquet.enableVectorizedReader")
+ .doc("Enables vectorized parquet decoding.")
+ .booleanConf
+ .createWithDefault(true)
+
+ val ORC_FILTER_PUSHDOWN_ENABLED = SQLConfigBuilder("spark.sql.orc.filterPushdown")
+ .doc("When true, enable filter pushdown for ORC files.")
+ .booleanConf
+ .createWithDefault(false)
+
+ val HIVE_VERIFY_PARTITION_PATH = SQLConfigBuilder("spark.sql.hive.verifyPartitionPath")
+ .doc("When true, check all the partition paths under the table\'s root directory " +
+ "when reading data stored in HDFS.")
+ .booleanConf
+ .createWithDefault(false)
+
+ val HIVE_METASTORE_PARTITION_PRUNING =
+ SQLConfigBuilder("spark.sql.hive.metastorePartitionPruning")
+ .doc("When true, some predicates will be pushed down into the Hive metastore so that " +
+ "unmatching partitions can be eliminated earlier.")
+ .booleanConf
+ .createWithDefault(false)
+
+ val NATIVE_VIEW = SQLConfigBuilder("spark.sql.nativeView")
+ .internal()
+ .doc("When true, CREATE VIEW will be handled by Spark SQL instead of Hive native commands. " +
+ "Note that this function is experimental and should ony be used when you are using " +
+ "non-hive-compatible tables written by Spark SQL. The SQL string used to create " +
+ "view should be fully qualified, i.e. use `tbl1`.`col1` instead of `*` whenever " +
+ "possible, or you may get wrong result.")
+ .booleanConf
+ .createWithDefault(true)
+
+ val CANONICAL_NATIVE_VIEW = SQLConfigBuilder("spark.sql.nativeView.canonical")
+ .internal()
+ .doc("When this option and spark.sql.nativeView are both true, Spark SQL tries to handle " +
+ "CREATE VIEW statement using SQL query string generated from view definition logical " +
+ "plan. If the logical plan doesn't have a SQL representation, we fallback to the " +
+ "original native view implementation.")
+ .booleanConf
+ .createWithDefault(true)
+
+ val COLUMN_NAME_OF_CORRUPT_RECORD = SQLConfigBuilder("spark.sql.columnNameOfCorruptRecord")
+ .doc("The name of internal column for storing raw/un-parsed JSON records that fail to parse.")
+ .stringConf
+ .createWithDefault("_corrupt_record")
+
+ val BROADCAST_TIMEOUT = SQLConfigBuilder("spark.sql.broadcastTimeout")
+ .doc("Timeout in seconds for the broadcast wait time in broadcast joins.")
+ .intConf
+ .createWithDefault(5 * 60)
// This is only used for the thriftserver
- val THRIFTSERVER_POOL = stringConf("spark.sql.thriftserver.scheduler.pool",
- doc = "Set a Fair Scheduler pool for a JDBC client session.")
-
- val THRIFTSERVER_UI_STATEMENT_LIMIT = intConf("spark.sql.thriftserver.ui.retainedStatements",
- defaultValue = Some(200),
- doc = "The number of SQL statements kept in the JDBC/ODBC web UI history.")
-
- val THRIFTSERVER_UI_SESSION_LIMIT = intConf("spark.sql.thriftserver.ui.retainedSessions",
- defaultValue = Some(200),
- doc = "The number of SQL client sessions kept in the JDBC/ODBC web UI history.")
+ val THRIFTSERVER_POOL = SQLConfigBuilder("spark.sql.thriftserver.scheduler.pool")
+ .doc("Set a Fair Scheduler pool for a JDBC client session.")
+ .stringConf
+ .createOptional
+
+ val THRIFTSERVER_UI_STATEMENT_LIMIT =
+ SQLConfigBuilder("spark.sql.thriftserver.ui.retainedStatements")
+ .doc("The number of SQL statements kept in the JDBC/ODBC web UI history.")
+ .intConf
+ .createWithDefault(200)
+
+ val THRIFTSERVER_UI_SESSION_LIMIT = SQLConfigBuilder("spark.sql.thriftserver.ui.retainedSessions")
+ .doc("The number of SQL client sessions kept in the JDBC/ODBC web UI history.")
+ .intConf
+ .createWithDefault(200)
// This is used to set the default data source
- val DEFAULT_DATA_SOURCE_NAME = stringConf("spark.sql.sources.default",
- defaultValue = Some("org.apache.spark.sql.parquet"),
- doc = "The default data source to use in input/output.")
+ val DEFAULT_DATA_SOURCE_NAME = SQLConfigBuilder("spark.sql.sources.default")
+ .doc("The default data source to use in input/output.")
+ .stringConf
+ .createWithDefault("org.apache.spark.sql.parquet")
// This is used to control when we will split a schema's JSON string into multiple pieces
// in order to fit the JSON string in the metastore's table property (by default, the value has
// a length restriction of 4000 characters). We will split the JSON string of a schema
// if its length exceeds the threshold.
- val SCHEMA_STRING_LENGTH_THRESHOLD = intConf("spark.sql.sources.schemaStringLengthThreshold",
- defaultValue = Some(4000),
- doc = "The maximum length allowed in a single cell when " +
- "storing additional schema information in Hive's metastore.",
- isPublic = false)
-
- val PARTITION_DISCOVERY_ENABLED = booleanConf("spark.sql.sources.partitionDiscovery.enabled",
- defaultValue = Some(true),
- doc = "When true, automatically discover data partitions.")
+ val SCHEMA_STRING_LENGTH_THRESHOLD =
+ SQLConfigBuilder("spark.sql.sources.schemaStringLengthThreshold")
+ .doc("The maximum length allowed in a single cell when " +
+ "storing additional schema information in Hive's metastore.")
+ .internal()
+ .intConf
+ .createWithDefault(4000)
+
+ val PARTITION_DISCOVERY_ENABLED = SQLConfigBuilder("spark.sql.sources.partitionDiscovery.enabled")
+ .doc("When true, automatically discover data partitions.")
+ .booleanConf
+ .createWithDefault(true)
val PARTITION_COLUMN_TYPE_INFERENCE =
- booleanConf("spark.sql.sources.partitionColumnTypeInference.enabled",
- defaultValue = Some(true),
- doc = "When true, automatically infer the data types for partitioned columns.")
+ SQLConfigBuilder("spark.sql.sources.partitionColumnTypeInference.enabled")
+ .doc("When true, automatically infer the data types for partitioned columns.")
+ .booleanConf
+ .createWithDefault(true)
val PARTITION_MAX_FILES =
- intConf("spark.sql.sources.maxConcurrentWrites",
- defaultValue = Some(1),
- doc = "The maximum number of concurrent files to open before falling back on sorting when " +
+ SQLConfigBuilder("spark.sql.sources.maxConcurrentWrites")
+ .doc("The maximum number of concurrent files to open before falling back on sorting when " +
"writing out files using dynamic partitioning.")
-
- val BUCKETING_ENABLED = booleanConf("spark.sql.sources.bucketing.enabled",
- defaultValue = Some(true),
- doc = "When false, we will treat bucketed table as normal table.")
-
- val ORDER_BY_ORDINAL = booleanConf("spark.sql.orderByOrdinal",
- defaultValue = Some(true),
- doc = "When true, the ordinal numbers are treated as the position in the select list. " +
- "When false, the ordinal numbers in order/sort By clause are ignored.")
-
- val GROUP_BY_ORDINAL = booleanConf("spark.sql.groupByOrdinal",
- defaultValue = Some(true),
- doc = "When true, the ordinal numbers in group by clauses are treated as the position " +
+ .intConf
+ .createWithDefault(1)
+
+ val BUCKETING_ENABLED = SQLConfigBuilder("spark.sql.sources.bucketing.enabled")
+ .doc("When false, we will treat bucketed table as normal table")
+ .booleanConf
+ .createWithDefault(true)
+
+ val ORDER_BY_ORDINAL = SQLConfigBuilder("spark.sql.orderByOrdinal")
+ .doc("When true, the ordinal numbers are treated as the position in the select list. " +
+ "When false, the ordinal numbers in order/sort By clause are ignored.")
+ .booleanConf
+ .createWithDefault(true)
+
+ val GROUP_BY_ORDINAL = SQLConfigBuilder("spark.sql.groupByOrdinal")
+ .doc("When true, the ordinal numbers in group by clauses are treated as the position " +
"in the select list. When false, the ordinal numbers are ignored.")
+ .booleanConf
+ .createWithDefault(true)
// The output committer class used by HadoopFsRelation. The specified class needs to be a
// subclass of org.apache.hadoop.mapreduce.OutputCommitter.
@@ -458,94 +339,102 @@ object SQLConf {
// 1. Instead of SQLConf, this option *must be set in Hadoop Configuration*.
// 2. This option can be overridden by "spark.sql.parquet.output.committer.class".
val OUTPUT_COMMITTER_CLASS =
- stringConf("spark.sql.sources.outputCommitterClass", isPublic = false)
+ SQLConfigBuilder("spark.sql.sources.outputCommitterClass").internal().stringConf.createOptional
- val PARALLEL_PARTITION_DISCOVERY_THRESHOLD = intConf(
- key = "spark.sql.sources.parallelPartitionDiscovery.threshold",
- defaultValue = Some(32),
- doc = "The degree of parallelism for schema merging and partition discovery of " +
- "Parquet data sources.")
+ val PARALLEL_PARTITION_DISCOVERY_THRESHOLD =
+ SQLConfigBuilder("spark.sql.sources.parallelPartitionDiscovery.threshold")
+ .doc("The degree of parallelism for schema merging and partition discovery of " +
+ "Parquet data sources.")
+ .intConf
+ .createWithDefault(32)
// Whether to perform eager analysis when constructing a dataframe.
// Set to false when debugging requires the ability to look at invalid query plans.
- val DATAFRAME_EAGER_ANALYSIS = booleanConf(
- "spark.sql.eagerAnalysis",
- defaultValue = Some(true),
- doc = "When true, eagerly applies query analysis on DataFrame operations.",
- isPublic = false)
+ val DATAFRAME_EAGER_ANALYSIS = SQLConfigBuilder("spark.sql.eagerAnalysis")
+ .internal()
+ .doc("When true, eagerly applies query analysis on DataFrame operations.")
+ .booleanConf
+ .createWithDefault(true)
// Whether to automatically resolve ambiguity in join conditions for self-joins.
// See SPARK-6231.
- val DATAFRAME_SELF_JOIN_AUTO_RESOLVE_AMBIGUITY = booleanConf(
- "spark.sql.selfJoinAutoResolveAmbiguity",
- defaultValue = Some(true),
- isPublic = false)
+ val DATAFRAME_SELF_JOIN_AUTO_RESOLVE_AMBIGUITY =
+ SQLConfigBuilder("spark.sql.selfJoinAutoResolveAmbiguity")
+ .internal()
+ .booleanConf
+ .createWithDefault(true)
// Whether to retain group by columns or not in GroupedData.agg.
- val DATAFRAME_RETAIN_GROUP_COLUMNS = booleanConf(
- "spark.sql.retainGroupColumns",
- defaultValue = Some(true),
- isPublic = false)
-
- val DATAFRAME_PIVOT_MAX_VALUES = intConf(
- "spark.sql.pivotMaxValues",
- defaultValue = Some(10000),
- doc = "When doing a pivot without specifying values for the pivot column this is the maximum " +
- "number of (distinct) values that will be collected without error."
- )
-
- val RUN_SQL_ON_FILES = booleanConf("spark.sql.runSQLOnFiles",
- defaultValue = Some(true),
- isPublic = false,
- doc = "When true, we could use `datasource`.`path` as table in SQL query."
- )
-
- val PARSER_SUPPORT_QUOTEDID = booleanConf("spark.sql.parser.supportQuotedIdentifiers",
- defaultValue = Some(true),
- isPublic = false,
- doc = "Whether to use quoted identifier.\n false: default(past) behavior. Implies only" +
- "alphaNumeric and underscore are valid characters in identifiers.\n" +
- " true: implies column names can contain any character.")
-
- val PARSER_SUPPORT_SQL11_RESERVED_KEYWORDS = booleanConf(
- "spark.sql.parser.supportSQL11ReservedKeywords",
- defaultValue = Some(false),
- isPublic = false,
- doc = "This flag should be set to true to enable support for SQL2011 reserved keywords.")
-
- val WHOLESTAGE_CODEGEN_ENABLED = booleanConf("spark.sql.codegen.wholeStage",
- defaultValue = Some(true),
- doc = "When true, the whole stage (of multiple operators) will be compiled into single java" +
- " method.",
- isPublic = false)
-
- val FILES_MAX_PARTITION_BYTES = longConf("spark.sql.files.maxPartitionBytes",
- defaultValue = Some(128 * 1024 * 1024), // parquet.block.size
- doc = "The maximum number of bytes to pack into a single partition when reading files.",
- isPublic = true)
-
- val EXCHANGE_REUSE_ENABLED = booleanConf("spark.sql.exchange.reuse",
- defaultValue = Some(true),
- doc = "When true, the planner will try to find out duplicated exchanges and re-use them.",
- isPublic = false)
-
- val STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT = intConf(
- "spark.sql.streaming.stateStore.minDeltasForSnapshot",
- defaultValue = Some(10),
- doc = "Minimum number of state store delta files that needs to be generated before they " +
- "consolidated into snapshots.",
- isPublic = false)
-
- val STATE_STORE_MIN_VERSIONS_TO_RETAIN = intConf(
- "spark.sql.streaming.stateStore.minBatchesToRetain",
- defaultValue = Some(2),
- doc = "Minimum number of versions of a state store's data to retain after cleaning.",
- isPublic = false)
-
- val CHECKPOINT_LOCATION = stringConf("spark.sql.streaming.checkpointLocation",
- defaultValue = None,
- doc = "The default location for storing checkpoint data for continuously executing queries.",
- isPublic = true)
+ val DATAFRAME_RETAIN_GROUP_COLUMNS = SQLConfigBuilder("spark.sql.retainGroupColumns")
+ .internal()
+ .booleanConf
+ .createWithDefault(true)
+
+ val DATAFRAME_PIVOT_MAX_VALUES = SQLConfigBuilder("spark.sql.pivotMaxValues")
+ .doc("When doing a pivot without specifying values for the pivot column this is the maximum " +
+ "number of (distinct) values that will be collected without error.")
+ .intConf
+ .createWithDefault(10000)
+
+ val RUN_SQL_ON_FILES = SQLConfigBuilder("spark.sql.runSQLOnFiles")
+ .internal()
+ .doc("When true, we could use `datasource`.`path` as table in SQL query.")
+ .booleanConf
+ .createWithDefault(true)
+
+ val WHOLESTAGE_CODEGEN_ENABLED = SQLConfigBuilder("spark.sql.codegen.wholeStage")
+ .internal()
+ .doc("When true, the whole stage (of multiple operators) will be compiled into single java" +
+ " method.")
+ .booleanConf
+ .createWithDefault(true)
+
+ val WHOLESTAGE_MAX_NUM_FIELDS = SQLConfigBuilder("spark.sql.codegen.maxFields")
+ .internal()
+ .doc("The maximum number of fields (including nested fields) that will be supported before" +
+ " deactivating whole-stage codegen.")
+ .intConf
+ .createWithDefault(200)
+
+ val FILES_MAX_PARTITION_BYTES = SQLConfigBuilder("spark.sql.files.maxPartitionBytes")
+ .doc("The maximum number of bytes to pack into a single partition when reading files.")
+ .longConf
+ .createWithDefault(128 * 1024 * 1024) // parquet.block.size
+
+ val FILES_OPEN_COST_IN_BYTES = SQLConfigBuilder("spark.sql.files.openCostInBytes")
+ .internal()
+ .doc("The estimated cost to open a file, measured by the number of bytes could be scanned in" +
+ " the same time. This is used when putting multiple files into a partition. It's better to" +
+ " over estimated, then the partitions with small files will be faster than partitions with" +
+ " bigger files (which is scheduled first).")
+ .longConf
+ .createWithDefault(4 * 1024 * 1024)
+
+ val EXCHANGE_REUSE_ENABLED = SQLConfigBuilder("spark.sql.exchange.reuse")
+ .internal()
+ .doc("When true, the planner will try to find out duplicated exchanges and re-use them.")
+ .booleanConf
+ .createWithDefault(true)
+
+ val STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT =
+ SQLConfigBuilder("spark.sql.streaming.stateStore.minDeltasForSnapshot")
+ .internal()
+ .doc("Minimum number of state store delta files that needs to be generated before they " +
+ "consolidated into snapshots.")
+ .intConf
+ .createWithDefault(10)
+
+ val STATE_STORE_MIN_VERSIONS_TO_RETAIN =
+ SQLConfigBuilder("spark.sql.streaming.stateStore.minBatchesToRetain")
+ .internal()
+ .doc("Minimum number of versions of a state store's data to retain after cleaning.")
+ .intConf
+ .createWithDefault(2)
+
+ val CHECKPOINT_LOCATION = SQLConfigBuilder("spark.sql.streaming.checkpointLocation")
+ .doc("The default location for storing checkpoint data for continuously executing queries.")
+ .stringConf
+ .createOptional
object Deprecated {
val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks"
@@ -568,7 +457,7 @@ object SQLConf {
*
* SQLConf is thread-safe (internally synchronized, so safe to be used in multiple threads).
*/
-class SQLConf extends Serializable with CatalystConf with ParserConf with Logging {
+private[sql] class SQLConf extends Serializable with CatalystConf with Logging {
import SQLConf._
/** Only low degree of contention is expected for conf, thus NOT using ConcurrentHashMap. */
@@ -581,14 +470,16 @@ class SQLConf extends Serializable with CatalystConf with ParserConf with Loggin
def filesMaxPartitionBytes: Long = getConf(FILES_MAX_PARTITION_BYTES)
+ def filesOpenCostInBytes: Long = getConf(FILES_OPEN_COST_IN_BYTES)
+
def useCompression: Boolean = getConf(COMPRESS_CACHED)
def parquetCompressionCodec: String = getConf(PARQUET_COMPRESSION)
- def parquetFileScan: Boolean = getConf(PARQUET_FILE_SCAN)
-
def parquetCacheMetadata: Boolean = getConf(PARQUET_CACHE_METADATA)
+ def parquetVectorizedReaderEnabled: Boolean = getConf(PARQUET_VECTORIZED_READER_ENABLED)
+
def columnBatchSize: Int = getConf(COLUMN_BATCH_SIZE)
def numShufflePartitions: Int = getConf(SHUFFLE_PARTITIONS)
@@ -613,6 +504,8 @@ class SQLConf extends Serializable with CatalystConf with ParserConf with Loggin
def wholeStageEnabled: Boolean = getConf(WHOLESTAGE_CODEGEN_ENABLED)
+ def wholeStageMaxNumFields: Int = getConf(WHOLESTAGE_MAX_NUM_FIELDS)
+
def exchangeReuseEnabled: Boolean = getConf(EXCHANGE_REUSE_ENABLED)
def canonicalView: Boolean = getConf(CANONICAL_NATIVE_VIEW)
@@ -667,10 +560,6 @@ class SQLConf extends Serializable with CatalystConf with ParserConf with Loggin
def runSQLOnFile: Boolean = getConf(RUN_SQL_ON_FILES)
- def supportQuotedId: Boolean = getConf(PARSER_SUPPORT_QUOTEDID)
-
- def supportSQL11ReservedKeywords: Boolean = getConf(PARSER_SUPPORT_SQL11_RESERVED_KEYWORDS)
-
override def orderByOrdinal: Boolean = getConf(ORDER_BY_ORDINAL)
override def groupByOrdinal: Boolean = getConf(GROUP_BY_ORDINAL)
@@ -694,7 +583,7 @@ class SQLConf extends Serializable with CatalystConf with ParserConf with Loggin
}
/** Set the given Spark SQL configuration property. */
- def setConf[T](entry: SQLConfEntry[T], value: T): Unit = {
+ def setConf[T](entry: ConfigEntry[T], value: T): Unit = {
require(entry != null, "entry cannot be null")
require(value != null, s"value cannot be null for key: ${entry.key}")
require(sqlConfEntries.get(entry.key) == entry, s"$entry is not registered")
@@ -714,25 +603,35 @@ class SQLConf extends Serializable with CatalystConf with ParserConf with Loggin
/**
* Return the value of Spark SQL configuration property for the given key. If the key is not set
- * yet, return `defaultValue`. This is useful when `defaultValue` in SQLConfEntry is not the
+ * yet, return `defaultValue`. This is useful when `defaultValue` in ConfigEntry is not the
* desired one.
*/
- def getConf[T](entry: SQLConfEntry[T], defaultValue: T): T = {
+ def getConf[T](entry: ConfigEntry[T], defaultValue: T): T = {
require(sqlConfEntries.get(entry.key) == entry, s"$entry is not registered")
Option(settings.get(entry.key)).map(entry.valueConverter).getOrElse(defaultValue)
}
/**
* Return the value of Spark SQL configuration property for the given key. If the key is not set
- * yet, return `defaultValue` in [[SQLConfEntry]].
+ * yet, return `defaultValue` in [[ConfigEntry]].
*/
- def getConf[T](entry: SQLConfEntry[T]): T = {
+ def getConf[T](entry: ConfigEntry[T]): T = {
require(sqlConfEntries.get(entry.key) == entry, s"$entry is not registered")
Option(settings.get(entry.key)).map(entry.valueConverter).orElse(entry.defaultValue).
getOrElse(throw new NoSuchElementException(entry.key))
}
/**
+ * Return the value of an optional Spark SQL configuration property for the given key. If the key
+ * is not set yet, throw an exception.
+ */
+ def getConf[T](entry: OptionalConfigEntry[T]): T = {
+ require(sqlConfEntries.get(entry.key) == entry, s"$entry is not registered")
+ Option(settings.get(entry.key)).map(entry.rawValueConverter).
+ getOrElse(throw new NoSuchElementException(entry.key))
+ }
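As a rough sketch of how these accessors fit together (internal API, so this only compiles inside the `org.apache.spark.sql` package; the values are arbitrary):

    // Sketch only: setting and reading registered entries on the internal SQLConf.
    import org.apache.spark.sql.internal.SQLConf

    val conf = new SQLConf
    conf.setConf(SQLConf.SHUFFLE_PARTITIONS, 50)            // typed setter; entry must be registered
    val n: Int = conf.getConf(SQLConf.SHUFFLE_PARTITIONS)   // falls back to the default (200) if unset
    // CHECKPOINT_LOCATION was created with createOptional, so it has no default:
    // conf.getConf(SQLConf.CHECKPOINT_LOCATION) would throw NoSuchElementException here.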
+
+ /**
* Return the `string` value of Spark SQL configuration property for the given key. If the key is
* not set yet, return `defaultValue`.
*/
@@ -773,7 +672,7 @@ class SQLConf extends Serializable with CatalystConf with ParserConf with Loggin
settings.remove(key)
}
- def unsetConf(entry: SQLConfEntry[_]): Unit = {
+ private[spark] def unsetConf(entry: ConfigEntry[_]): Unit = {
settings.remove(entry.key)
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala
index e5f02caabc..69e3358d4e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala
@@ -44,14 +44,19 @@ private[sql] class SessionState(ctx: SQLContext) {
lazy val experimentalMethods = new ExperimentalMethods
/**
- * Internal catalog for managing table and database states.
+ * Internal catalog for managing functions registered by the user.
*/
- lazy val catalog = new SessionCatalog(ctx.externalCatalog, conf)
+ lazy val functionRegistry: FunctionRegistry = FunctionRegistry.builtin.copy()
/**
- * Internal catalog for managing functions registered by the user.
+ * Internal catalog for managing table and database states.
*/
- lazy val functionRegistry: FunctionRegistry = FunctionRegistry.builtin.copy()
+ lazy val catalog =
+ new SessionCatalog(
+ ctx.externalCatalog,
+ ctx.functionResourceLoader,
+ functionRegistry,
+ conf)
/**
* Interface exposed to the user for registering user-defined functions.
@@ -62,9 +67,8 @@ private[sql] class SessionState(ctx: SQLContext) {
* Logical query plan analyzer for resolving unresolved attributes and relations.
*/
lazy val analyzer: Analyzer = {
- new Analyzer(catalog, functionRegistry, conf) {
+ new Analyzer(catalog, conf) {
override val extendedResolutionRules =
- python.ExtractPythonUDFs ::
PreInsertCastAndRename ::
DataSourceAnalysis ::
(if (conf.runSQLOnFile) new ResolveDataSource(ctx) :: Nil else Nil)
@@ -81,25 +85,13 @@ private[sql] class SessionState(ctx: SQLContext) {
/**
* Parser that extracts expressions, plans, table identifiers etc. from SQL texts.
*/
- lazy val sqlParser: ParserInterface = new SparkQl(conf)
+ lazy val sqlParser: ParserInterface = SparkSqlParser
/**
* Planner that converts optimized logical plans to physical plans.
*/
- lazy val planner: SparkPlanner = new SparkPlanner(ctx.sparkContext, conf, experimentalMethods)
-
- /**
- * Prepares a planned [[SparkPlan]] for execution by inserting shuffle operations and internal
- * row format conversions as needed.
- */
- lazy val prepareForExecution = new RuleExecutor[SparkPlan] {
- override val batches: Seq[Batch] = Seq(
- Batch("Subquery", Once, PlanSubqueries(SessionState.this)),
- Batch("Add exchange", Once, EnsureRequirements(conf)),
- Batch("Whole stage codegen", Once, CollapseCodegenStages(conf)),
- Batch("Reuse duplicated exchanges", Once, ReuseExchange(conf))
- )
- }
+ def planner: SparkPlanner =
+ new SparkPlanner(ctx.sparkContext, conf, experimentalMethods.extraStrategies)
/**
* An interface to register custom [[org.apache.spark.sql.util.QueryExecutionListener]]s
@@ -111,5 +103,5 @@ private[sql] class SessionState(ctx: SQLContext) {
* Interface to start and stop [[org.apache.spark.sql.ContinuousQuery]]s.
*/
lazy val continuousQueryManager: ContinuousQueryManager = new ContinuousQueryManager(ctx)
-
}
+
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
index ca2d909e2c..948106fd06 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
@@ -100,11 +100,11 @@ abstract class JdbcDialect extends Serializable {
}
/**
- * Override connection specific properties to run before a select is made. This is in place to
- * allow dialects that need special treatment to optimize behavior.
- * @param connection The connection object
- * @param properties The connection properties. This is passed through from the relation.
- */
+ * Override connection specific properties to run before a select is made. This is in place to
+ * allow dialects that need special treatment to optimize behavior.
+ * @param connection The connection object
+ * @param properties The connection properties. This is passed through from the relation.
+ */
def beforeFetch(connection: Connection, properties: Map[String, String]): Unit = {
}
@@ -126,7 +126,7 @@ object JdbcDialects {
/**
* Register a dialect for use on all new matching jdbc [[org.apache.spark.sql.DataFrame]].
- * Readding an existing dialect will cause a move-to-front.
+ * Re-adding an existing dialect will cause a move-to-front.
*
* @param dialect The new dialect.
*/
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
index 1e02354edf..4b9bf8daae 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
@@ -129,8 +129,17 @@ trait SchemaRelationProvider {
* Implemented by objects that can produce a streaming [[Source]] for a specific format or system.
*/
trait StreamSourceProvider {
+
+ /** Returns the name and schema of the source that can be used to continually read data. */
+ def sourceSchema(
+ sqlContext: SQLContext,
+ schema: Option[StructType],
+ providerName: String,
+ parameters: Map[String, String]): (String, StructType)
+
def createSource(
sqlContext: SQLContext,
+ metadataPath: String,
schema: Option[StructType],
providerName: String,
parameters: Map[String, String]): Source
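A rough sketch of an implementation of the extended trait (the provider class and its schema are hypothetical; the import paths assume this snapshot's package layout):

    // Hypothetical provider: reports its schema up front, then builds a Source that
    // may persist its own metadata under the supplied metadataPath.
    import org.apache.spark.sql.SQLContext
    import org.apache.spark.sql.execution.streaming.Source
    import org.apache.spark.sql.sources.StreamSourceProvider
    import org.apache.spark.sql.types.{StringType, StructType}

    class EchoSourceProvider extends StreamSourceProvider {
      override def sourceSchema(
          sqlContext: SQLContext,
          schema: Option[StructType],
          providerName: String,
          parameters: Map[String, String]): (String, StructType) =
        (providerName, schema.getOrElse(new StructType().add("value", StringType)))

      override def createSource(
          sqlContext: SQLContext,
          metadataPath: String,
          schema: Option[StructType],
          providerName: String,
          parameters: Map[String, String]): Source =
        ???  // sketch only: build the Source, persisting state under metadataPath
    }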
@@ -152,19 +161,19 @@ trait StreamSinkProvider {
@DeveloperApi
trait CreatableRelationProvider {
/**
- * Creates a relation with the given parameters based on the contents of the given
- * DataFrame. The mode specifies the expected behavior of createRelation when
- * data already exists.
- * Right now, there are three modes, Append, Overwrite, and ErrorIfExists.
- * Append mode means that when saving a DataFrame to a data source, if data already exists,
- * contents of the DataFrame are expected to be appended to existing data.
- * Overwrite mode means that when saving a DataFrame to a data source, if data already exists,
- * existing data is expected to be overwritten by the contents of the DataFrame.
- * ErrorIfExists mode means that when saving a DataFrame to a data source,
- * if data already exists, an exception is expected to be thrown.
- *
- * @since 1.3.0
- */
+ * Creates a relation with the given parameters based on the contents of the given
+ * DataFrame. The mode specifies the expected behavior of createRelation when
+ * data already exists.
+ * Right now, there are three modes, Append, Overwrite, and ErrorIfExists.
+ * Append mode means that when saving a DataFrame to a data source, if data already exists,
+ * contents of the DataFrame are expected to be appended to existing data.
+ * Overwrite mode means that when saving a DataFrame to a data source, if data already exists,
+ * existing data is expected to be overwritten by the contents of the DataFrame.
+ * ErrorIfExists mode means that when saving a DataFrame to a data source,
+ * if data already exists, an exception is expected to be thrown.
+ *
+ * @since 1.3.0
+ */
def createRelation(
sqlContext: SQLContext,
mode: SaveMode,
@@ -385,9 +394,9 @@ abstract class OutputWriter {
*
* @param location A [[FileCatalog]] that can enumerate the locations of all the files that comprise
* this relation.
- * @param partitionSchema The schmea of the columns (if any) that are used to partition the relation
+ * @param partitionSchema The schema of the columns (if any) that are used to partition the relation
* @param dataSchema The schema of any remaining columns. Note that if any partition columns are
- * present in the actual data files as well, they are removed.
+ * present in the actual data files as well, they are preserved.
* @param bucketSpec Describes the bucketing (hash-partitioning of the files by some column values).
* @param fileFormat A file format that can be used to read and write the data in files.
* @param options Configuration used when reading / writing data.
@@ -439,6 +448,15 @@ trait FileFormat {
files: Seq[FileStatus]): Option[StructType]
/**
+ * Prepares a read job and returns a potentially updated data source option [[Map]]. This method
+ * can be useful for collecting necessary global information for scanning input data.
+ */
+ def prepareRead(
+ sqlContext: SQLContext,
+ options: Map[String, String],
+ files: Seq[FileStatus]): Map[String, String] = options
+
+ /**
* Prepares a write job and returns an [[OutputWriterFactory]]. Client side job preparation can
* be put here. For example, user defined output committer can be configured here
* by setting the output committer class in the conf of spark.sql.sources.outputCommitterClass.
@@ -449,33 +467,36 @@ trait FileFormat {
options: Map[String, String],
dataSchema: StructType): OutputWriterFactory
- def buildInternalScan(
- sqlContext: SQLContext,
- dataSchema: StructType,
- requiredColumns: Array[String],
- filters: Array[Filter],
- bucketSet: Option[BitSet],
- inputFiles: Seq[FileStatus],
- broadcastedConf: Broadcast[SerializableConfiguration],
- options: Map[String, String]): RDD[InternalRow]
+ /**
+ * Returns whether this format support returning columnar batch or not.
+ *
+ * TODO: we should just have different traits for the different formats.
+ */
+ def supportBatch(sqlContext: SQLContext, dataSchema: StructType): Boolean = {
+ false
+ }
/**
* Returns a function that can be used to read a single file in as an Iterator of InternalRow.
*
+ * @param dataSchema The global data schema. It can be either specified by the user, or
+ * reconciled/merged from all underlying data files. If any partition columns
+ * are contained in the files, they are preserved in this schema.
* @param partitionSchema The schema of the partition column row that will be present in each
- * PartitionedFile. These columns should be prepended to the rows that
+ * PartitionedFile. These columns should be appended to the rows that
* are produced by the iterator.
- * @param dataSchema The schema of the data that should be output for each row. This may be a
- * subset of the columns that are present in the file if column pruning has
- * occurred.
+ * @param requiredSchema The schema of the data that should be output for each row. This may be a
+ * subset of the columns that are present in the file if column pruning has
+ * occurred.
* @param filters A set of filters that can optionally be used to reduce the number of rows output
* @param options A set of string -> string configuration options.
* @return
*/
def buildReader(
sqlContext: SQLContext,
- partitionSchema: StructType,
dataSchema: StructType,
+ partitionSchema: StructType,
+ requiredSchema: StructType,
filters: Seq[Filter],
options: Map[String, String]): PartitionedFile => Iterator[InternalRow] = {
// TODO: Remove this default implementation when the other formats have been ported
@@ -572,10 +593,7 @@ class HDFSFileCatalog(
}
if (partitionPruningPredicates.nonEmpty) {
- val predicate =
- partitionPruningPredicates
- .reduceOption(expressions.And)
- .getOrElse(Literal(true))
+ val predicate = partitionPruningPredicates.reduce(expressions.And)
val boundPredicate = InterpretedPredicate.create(predicate.transform {
case a: AttributeReference =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/util/ContinuousQueryListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/util/ContinuousQueryListener.scala
index 2c5358cbd7..ba1facf11b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/util/ContinuousQueryListener.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/util/ContinuousQueryListener.scala
@@ -34,15 +34,23 @@ abstract class ContinuousQueryListener {
* @note This is called synchronously with
* [[org.apache.spark.sql.DataFrameWriter `DataFrameWriter.startStream()`]],
* that is, `onQueryStart` will be called on all listeners before
- * `DataFrameWriter.startStream()` returns the corresponding [[ContinuousQuery]].
+ * `DataFrameWriter.startStream()` returns the corresponding [[ContinuousQuery]]. Please
+ * don't block this method as it will block your query.
*/
- def onQueryStarted(queryStarted: QueryStarted)
+ def onQueryStarted(queryStarted: QueryStarted): Unit
- /** Called when there is some status update (ingestion rate updated, etc. */
- def onQueryProgress(queryProgress: QueryProgress)
+ /**
+ * Called when there is some status update (ingestion rate updated, etc.)
+ *
+ * @note This method is asynchronous. The status in [[ContinuousQuery]] will always be
+ * latest no matter when this method is called. Therefore, the status of [[ContinuousQuery]]
+ * may be changed before/when you process the event. E.g., you may find [[ContinuousQuery]]
+ * is terminated when you are processing [[QueryProgress]].
+ */
+ def onQueryProgress(queryProgress: QueryProgress): Unit
/** Called when a query is stopped, with or without error */
- def onQueryTerminated(queryTerminated: QueryTerminated)
+ def onQueryTerminated(queryTerminated: QueryTerminated): Unit
}
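A minimal sketch of a conforming listener under the new `Unit` signatures (the class name is made up, the bodies only log, and the event classes are assumed to be the ones nested in ContinuousQueryListener's companion object, as referenced by the Scaladoc above):

    // Hypothetical listener: note the explicit Unit return types required by the new signatures.
    import org.apache.spark.sql.util.ContinuousQueryListener
    import org.apache.spark.sql.util.ContinuousQueryListener._

    class LoggingListener extends ContinuousQueryListener {
      // Called synchronously from DataFrameWriter.startStream(); keep this cheap.
      override def onQueryStarted(queryStarted: QueryStarted): Unit =
        println("query started")
      // Asynchronous: the query may already have changed state (or terminated) by now.
      override def onQueryProgress(queryProgress: QueryProgress): Unit =
        println("query progress")
      override def onQueryTerminated(queryTerminated: QueryTerminated): Unit =
        println("query terminated")
    }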
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
index a6c819373b..5abd62cbc2 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
@@ -37,7 +37,6 @@ import org.apache.spark.SparkContext;
import org.apache.spark.api.java.function.*;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.*;
-import org.apache.spark.sql.expressions.Aggregator;
import org.apache.spark.sql.test.TestSQLContext;
import org.apache.spark.sql.catalyst.encoders.OuterScopes;
import org.apache.spark.sql.catalyst.expressions.GenericRow;
@@ -87,6 +86,16 @@ public class JavaDatasetSuite implements Serializable {
}
@Test
+ public void testToLocalIterator() {
+ List<String> data = Arrays.asList("hello", "world");
+ Dataset<String> ds = context.createDataset(data, Encoders.STRING());
+ Iterator<String> iter = ds.toLocalIterator();
+ Assert.assertEquals("hello", iter.next());
+ Assert.assertEquals("world", iter.next());
+ Assert.assertFalse(iter.hasNext());
+ }
+
+ @Test
public void testCommonOperation() {
List<String> data = Arrays.asList("hello", "world");
Dataset<String> ds = context.createDataset(data, Encoders.STRING());
@@ -319,14 +328,14 @@ public class JavaDatasetSuite implements Serializable {
Encoder<Tuple3<Integer, Long, String>> encoder3 =
Encoders.tuple(Encoders.INT(), Encoders.LONG(), Encoders.STRING());
List<Tuple3<Integer, Long, String>> data3 =
- Arrays.asList(new Tuple3<Integer, Long, String>(1, 2L, "a"));
+ Arrays.asList(new Tuple3<>(1, 2L, "a"));
Dataset<Tuple3<Integer, Long, String>> ds3 = context.createDataset(data3, encoder3);
Assert.assertEquals(data3, ds3.collectAsList());
Encoder<Tuple4<Integer, String, Long, String>> encoder4 =
Encoders.tuple(Encoders.INT(), Encoders.STRING(), Encoders.LONG(), Encoders.STRING());
List<Tuple4<Integer, String, Long, String>> data4 =
- Arrays.asList(new Tuple4<Integer, String, Long, String>(1, "b", 2L, "a"));
+ Arrays.asList(new Tuple4<>(1, "b", 2L, "a"));
Dataset<Tuple4<Integer, String, Long, String>> ds4 = context.createDataset(data4, encoder4);
Assert.assertEquals(data4, ds4.collectAsList());
@@ -334,7 +343,7 @@ public class JavaDatasetSuite implements Serializable {
Encoders.tuple(Encoders.INT(), Encoders.STRING(), Encoders.LONG(), Encoders.STRING(),
Encoders.BOOLEAN());
List<Tuple5<Integer, String, Long, String, Boolean>> data5 =
- Arrays.asList(new Tuple5<Integer, String, Long, String, Boolean>(1, "b", 2L, "a", true));
+ Arrays.asList(new Tuple5<>(1, "b", 2L, "a", true));
Dataset<Tuple5<Integer, String, Long, String, Boolean>> ds5 =
context.createDataset(data5, encoder5);
Assert.assertEquals(data5, ds5.collectAsList());
@@ -355,7 +364,7 @@ public class JavaDatasetSuite implements Serializable {
Encoders.tuple(Encoders.INT(),
Encoders.tuple(Encoders.STRING(), Encoders.STRING(), Encoders.LONG()));
List<Tuple2<Integer, Tuple3<String, String, Long>>> data2 =
- Arrays.asList(tuple2(1, new Tuple3<String, String, Long>("a", "b", 3L)));
+ Arrays.asList(tuple2(1, new Tuple3<>("a", "b", 3L)));
Dataset<Tuple2<Integer, Tuple3<String, String, Long>>> ds2 =
context.createDataset(data2, encoder2);
Assert.assertEquals(data2, ds2.collectAsList());
@@ -377,7 +386,7 @@ public class JavaDatasetSuite implements Serializable {
Encoders.tuple(Encoders.DOUBLE(), Encoders.DECIMAL(), Encoders.DATE(), Encoders.TIMESTAMP(),
Encoders.FLOAT());
List<Tuple5<Double, BigDecimal, Date, Timestamp, Float>> data =
- Arrays.asList(new Tuple5<Double, BigDecimal, Date, Timestamp, Float>(
+ Arrays.asList(new Tuple5<>(
1.7976931348623157E308, new BigDecimal("0.922337203685477589"),
Date.valueOf("1970-01-01"), new Timestamp(System.currentTimeMillis()), Float.MAX_VALUE));
Dataset<Tuple5<Double, BigDecimal, Date, Timestamp, Float>> ds =
@@ -385,59 +394,6 @@ public class JavaDatasetSuite implements Serializable {
Assert.assertEquals(data, ds.collectAsList());
}
- @Test
- public void testTypedAggregation() {
- Encoder<Tuple2<String, Integer>> encoder = Encoders.tuple(Encoders.STRING(), Encoders.INT());
- List<Tuple2<String, Integer>> data =
- Arrays.asList(tuple2("a", 1), tuple2("a", 2), tuple2("b", 3));
- Dataset<Tuple2<String, Integer>> ds = context.createDataset(data, encoder);
-
- KeyValueGroupedDataset<String, Tuple2<String, Integer>> grouped = ds.groupByKey(
- new MapFunction<Tuple2<String, Integer>, String>() {
- @Override
- public String call(Tuple2<String, Integer> value) throws Exception {
- return value._1();
- }
- },
- Encoders.STRING());
-
- Dataset<Tuple2<String, Integer>> agged =
- grouped.agg(new IntSumOf().toColumn(Encoders.INT(), Encoders.INT()));
- Assert.assertEquals(Arrays.asList(tuple2("a", 3), tuple2("b", 3)), agged.collectAsList());
-
- Dataset<Tuple2<String, Integer>> agged2 = grouped.agg(
- new IntSumOf().toColumn(Encoders.INT(), Encoders.INT()))
- .as(Encoders.tuple(Encoders.STRING(), Encoders.INT()));
- Assert.assertEquals(
- Arrays.asList(
- new Tuple2<>("a", 3),
- new Tuple2<>("b", 3)),
- agged2.collectAsList());
- }
-
- static class IntSumOf extends Aggregator<Tuple2<String, Integer>, Integer, Integer> {
-
- @Override
- public Integer zero() {
- return 0;
- }
-
- @Override
- public Integer reduce(Integer l, Tuple2<String, Integer> t) {
- return l + t._2();
- }
-
- @Override
- public Integer merge(Integer b1, Integer b2) {
- return b1 + b2;
- }
-
- @Override
- public Integer finish(Integer reduction) {
- return reduction;
- }
- }
-
public static class KryoSerializable {
String value;
@@ -498,6 +454,16 @@ public class JavaDatasetSuite implements Serializable {
Assert.assertEquals(data, ds.collectAsList());
}
+ @Test
+ public void testRandomSplit() {
+ List<String> data = Arrays.asList("hello", "world", "from", "spark");
+ Dataset<String> ds = context.createDataset(data, Encoders.STRING());
+ double[] arraySplit = {1, 2, 3};
+
+ List<Dataset<String>> randomSplit = ds.randomSplitAsList(arraySplit, 1);
+    Assert.assertEquals("wrong number of splits", 3, randomSplit.size());
+ }
+
/**
* For testing error messages when creating an encoder on a private class. This is done
* here since we cannot create truly private classes in Scala.
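The testRandomSplit case added above drives Dataset.randomSplitAsList, the Java wrapper around randomSplit. Below is a minimal Scala sketch of the same behaviour, assuming a local SQLContext; the object name RandomSplitSketch and the sample strings are illustrative. The weights are normalized internally, so {1, 2, 3} always yields three splits that together cover the input.

    import org.apache.spark.SparkContext
    import org.apache.spark.sql.SQLContext

    object RandomSplitSketch {
      def main(args: Array[String]): Unit = {
        val sc = new SparkContext("local[*]", "random-split-sketch")
        val sqlContext = new SQLContext(sc)
        import sqlContext.implicits._

        val ds = Seq("hello", "world", "from", "spark").toDS()
        // A fixed seed makes the split reproducible; weights need not sum to 1.
        val splits = ds.randomSplit(Array(1.0, 2.0, 3.0), seed = 1)
        assert(splits.length == 3)
        assert(splits.map(_.count()).sum == ds.count())
        sc.stop()
      }
    }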
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaDatasetAggregatorSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaDatasetAggregatorSuite.java
new file mode 100644
index 0000000000..0e49f871de
--- /dev/null
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaDatasetAggregatorSuite.java
@@ -0,0 +1,134 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package test.org.apache.spark.sql.sources;
+
+import java.util.Arrays;
+
+import scala.Tuple2;
+
+import org.junit.Assert;
+import org.junit.Test;
+
+import org.apache.spark.api.java.function.MapFunction;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Encoder;
+import org.apache.spark.sql.Encoders;
+import org.apache.spark.sql.KeyValueGroupedDataset;
+import org.apache.spark.sql.expressions.Aggregator;
+import org.apache.spark.sql.expressions.java.typed;
+
+/**
+ * Suite for testing the aggregate functionality of Datasets in Java.
+ */
+public class JavaDatasetAggregatorSuite extends JavaDatasetAggregatorSuiteBase {
+ @Test
+ public void testTypedAggregationAnonClass() {
+ KeyValueGroupedDataset<String, Tuple2<String, Integer>> grouped = generateGroupedDataset();
+
+ Dataset<Tuple2<String, Integer>> agged = grouped.agg(new IntSumOf().toColumn());
+ Assert.assertEquals(Arrays.asList(tuple2("a", 3), tuple2("b", 3)), agged.collectAsList());
+
+ Dataset<Tuple2<String, Integer>> agged2 = grouped.agg(new IntSumOf().toColumn())
+ .as(Encoders.tuple(Encoders.STRING(), Encoders.INT()));
+ Assert.assertEquals(
+ Arrays.asList(
+ new Tuple2<>("a", 3),
+ new Tuple2<>("b", 3)),
+ agged2.collectAsList());
+ }
+
+ static class IntSumOf extends Aggregator<Tuple2<String, Integer>, Integer, Integer> {
+ @Override
+ public Integer zero() {
+ return 0;
+ }
+
+ @Override
+ public Integer reduce(Integer l, Tuple2<String, Integer> t) {
+ return l + t._2();
+ }
+
+ @Override
+ public Integer merge(Integer b1, Integer b2) {
+ return b1 + b2;
+ }
+
+ @Override
+ public Integer finish(Integer reduction) {
+ return reduction;
+ }
+
+ @Override
+ public Encoder<Integer> bufferEncoder() {
+ return Encoders.INT();
+ }
+
+ @Override
+ public Encoder<Integer> outputEncoder() {
+ return Encoders.INT();
+ }
+ }
+
+ @Test
+ public void testTypedAggregationAverage() {
+ KeyValueGroupedDataset<String, Tuple2<String, Integer>> grouped = generateGroupedDataset();
+ Dataset<Tuple2<String, Double>> agged = grouped.agg(typed.avg(
+ new MapFunction<Tuple2<String, Integer>, Double>() {
+ public Double call(Tuple2<String, Integer> value) throws Exception {
+ return (double)(value._2() * 2);
+ }
+ }));
+ Assert.assertEquals(Arrays.asList(tuple2("a", 3.0), tuple2("b", 6.0)), agged.collectAsList());
+ }
+
+ @Test
+ public void testTypedAggregationCount() {
+ KeyValueGroupedDataset<String, Tuple2<String, Integer>> grouped = generateGroupedDataset();
+ Dataset<Tuple2<String, Long>> agged = grouped.agg(typed.count(
+ new MapFunction<Tuple2<String, Integer>, Object>() {
+ public Object call(Tuple2<String, Integer> value) throws Exception {
+ return value;
+ }
+ }));
+ Assert.assertEquals(Arrays.asList(tuple2("a", 2), tuple2("b", 1)), agged.collectAsList());
+ }
+
+ @Test
+ public void testTypedAggregationSumDouble() {
+ KeyValueGroupedDataset<String, Tuple2<String, Integer>> grouped = generateGroupedDataset();
+ Dataset<Tuple2<String, Double>> agged = grouped.agg(typed.sum(
+ new MapFunction<Tuple2<String, Integer>, Double>() {
+ public Double call(Tuple2<String, Integer> value) throws Exception {
+ return (double)value._2();
+ }
+ }));
+ Assert.assertEquals(Arrays.asList(tuple2("a", 3.0), tuple2("b", 3.0)), agged.collectAsList());
+ }
+
+ @Test
+ public void testTypedAggregationSumLong() {
+ KeyValueGroupedDataset<String, Tuple2<String, Integer>> grouped = generateGroupedDataset();
+ Dataset<Tuple2<String, Long>> agged = grouped.agg(typed.sumLong(
+ new MapFunction<Tuple2<String, Integer>, Long>() {
+ public Long call(Tuple2<String, Integer> value) throws Exception {
+ return (long)value._2();
+ }
+ }));
+ Assert.assertEquals(Arrays.asList(tuple2("a", 3), tuple2("b", 3)), agged.collectAsList());
+ }
+}
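The suite above exercises the reworked Aggregator API, where each aggregator now supplies its own bufferEncoder and outputEncoder and toColumn takes no arguments. Below is a minimal Scala counterpart to IntSumOf, assuming a local SQLContext; the object names are illustrative.

    import org.apache.spark.SparkContext
    import org.apache.spark.sql.{Encoder, Encoders, SQLContext}
    import org.apache.spark.sql.expressions.Aggregator

    object IntSumSketch {
      // Sums the Int component of (String, Int) pairs, mirroring IntSumOf above.
      object SumOfPairs extends Aggregator[(String, Int), Int, Int] {
        def zero: Int = 0
        def reduce(b: Int, a: (String, Int)): Int = b + a._2
        def merge(b1: Int, b2: Int): Int = b1 + b2
        def finish(r: Int): Int = r
        def bufferEncoder: Encoder[Int] = Encoders.scalaInt
        def outputEncoder: Encoder[Int] = Encoders.scalaInt
      }

      def main(args: Array[String]): Unit = {
        val sc = new SparkContext("local[*]", "aggregator-sketch")
        val sqlContext = new SQLContext(sc)
        import sqlContext.implicits._

        val ds = Seq(("a", 1), ("a", 2), ("b", 3)).toDS()
        // groupByKey keeps the key; agg appends the aggregated value per key.
        val summed = ds.groupByKey(_._1).agg(SumOfPairs.toColumn)
        assert(summed.collect().toSet == Set(("a", 3), ("b", 3)))
        sc.stop()
      }
    }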
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaDatasetAggregatorSuiteBase.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaDatasetAggregatorSuiteBase.java
new file mode 100644
index 0000000000..7863177093
--- /dev/null
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaDatasetAggregatorSuiteBase.java
@@ -0,0 +1,81 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package test.org.apache.spark.sql.sources;
+
+import java.io.Serializable;
+import java.util.Arrays;
+import java.util.List;
+
+import scala.Tuple2;
+
+import org.junit.After;
+import org.junit.Before;
+
+import org.apache.spark.SparkContext;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.api.java.function.MapFunction;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Encoder;
+import org.apache.spark.sql.Encoders;
+import org.apache.spark.sql.KeyValueGroupedDataset;
+import org.apache.spark.sql.test.TestSQLContext;
+
+/**
+ * Common test base shared across this and Java8DatasetAggregatorSuite.
+ */
+public class JavaDatasetAggregatorSuiteBase implements Serializable {
+ protected transient JavaSparkContext jsc;
+ protected transient TestSQLContext context;
+
+ @Before
+ public void setUp() {
+ // Trigger static initializer of TestData
+ SparkContext sc = new SparkContext("local[*]", "testing");
+ jsc = new JavaSparkContext(sc);
+ context = new TestSQLContext(sc);
+ context.loadTestData();
+ }
+
+ @After
+ public void tearDown() {
+ context.sparkContext().stop();
+ context = null;
+ jsc = null;
+ }
+
+ protected <T1, T2> Tuple2<T1, T2> tuple2(T1 t1, T2 t2) {
+ return new Tuple2<>(t1, t2);
+ }
+
+ protected KeyValueGroupedDataset<String, Tuple2<String, Integer>> generateGroupedDataset() {
+ Encoder<Tuple2<String, Integer>> encoder = Encoders.tuple(Encoders.STRING(), Encoders.INT());
+ List<Tuple2<String, Integer>> data =
+ Arrays.asList(tuple2("a", 1), tuple2("a", 2), tuple2("b", 3));
+ Dataset<Tuple2<String, Integer>> ds = context.createDataset(data, encoder);
+
+ return ds.groupByKey(
+ new MapFunction<Tuple2<String, Integer>, String>() {
+ @Override
+ public String call(Tuple2<String, Integer> value) throws Exception {
+ return value._1();
+ }
+ },
+ Encoders.STRING());
+ }
+}
+
diff --git a/sql/core/src/test/resources/unescaped-quotes.csv b/sql/core/src/test/resources/unescaped-quotes.csv
new file mode 100644
index 0000000000..7c68055575
--- /dev/null
+++ b/sql/core/src/test/resources/unescaped-quotes.csv
@@ -0,0 +1,2 @@
+"a"b,ccc,ddd
+ab,cc"c,ddd"
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala
index e34875471f..18e04c24a4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala
@@ -141,26 +141,36 @@ class DataFrameNaFunctionsSuite extends QueryTest with SharedSQLContext {
}
test("fill with map") {
- val df = Seq[(String, String, java.lang.Long, java.lang.Double, java.lang.Boolean)](
- (null, null, null, null, null)).toDF("a", "b", "c", "d", "e")
- checkAnswer(
- df.na.fill(Map(
- "a" -> "test",
- "c" -> 1,
- "d" -> 2.2,
- "e" -> false
- )),
- Row("test", null, 1, 2.2, false))
-
- // Test Java version
- checkAnswer(
- df.na.fill(Map(
- "a" -> "test",
- "c" -> 1,
- "d" -> 2.2,
- "e" -> false
- ).asJava),
- Row("test", null, 1, 2.2, false))
+ val df = Seq[(String, String, java.lang.Integer, java.lang.Long,
+ java.lang.Float, java.lang.Double, java.lang.Boolean)](
+ (null, null, null, null, null, null, null))
+ .toDF("stringFieldA", "stringFieldB", "integerField", "longField",
+ "floatField", "doubleField", "booleanField")
+
+ val fillMap = Map(
+ "stringFieldA" -> "test",
+ "integerField" -> 1,
+ "longField" -> 2L,
+ "floatField" -> 3.3f,
+ "doubleField" -> 4.4d,
+ "booleanField" -> false)
+
+ val expectedRow = Row("test", null, 1, 2L, 3.3f, 4.4d, false)
+
+ checkAnswer(df.na.fill(fillMap), expectedRow)
+ checkAnswer(df.na.fill(fillMap.asJava), expectedRow) // Test Java version
+
+ // Ensure replacement values are cast to the column data type.
+ checkAnswer(df.na.fill(Map(
+ "integerField" -> 1d,
+ "longField" -> 2d,
+ "floatField" -> 3d,
+ "doubleField" -> 4d)),
+ Row(null, null, 1, 2L, 3f, 4d, null))
+
+ // Ensure column types do not change. Columns that have null values replaced
+ // will no longer be flagged as nullable, so do not compare schemas directly.
+ assert(df.na.fill(fillMap).schema.fields.map(_.dataType) === df.schema.fields.map(_.dataType))
}
test("replace") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 86c6405522..e953a6e8ef 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -1153,14 +1153,12 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
private def verifyNonExchangingAgg(df: DataFrame) = {
var atFirstAgg: Boolean = false
df.queryExecution.executedPlan.foreach {
- case agg: TungstenAggregate => {
+ case agg: TungstenAggregate =>
atFirstAgg = !atFirstAgg
- }
- case _ => {
+ case _ =>
if (atFirstAgg) {
fail("Should not have operators between the two aggregations")
}
- }
}
}
@@ -1170,12 +1168,11 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
private def verifyExchangingAgg(df: DataFrame) = {
var atFirstAgg: Boolean = false
df.queryExecution.executedPlan.foreach {
- case agg: TungstenAggregate => {
+ case agg: TungstenAggregate =>
if (atFirstAgg) {
fail("Should not have back to back Aggregates")
}
atFirstAgg = true
- }
case e: ShuffleExchange => atFirstAgg = false
case _ =>
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTimeWindowingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTimeWindowingSuite.scala
new file mode 100644
index 0000000000..06584ec21e
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTimeWindowingSuite.scala
@@ -0,0 +1,299 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import java.util.TimeZone
+
+import org.scalatest.BeforeAndAfterEach
+
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.test.SharedSQLContext
+import org.apache.spark.sql.types.StringType
+
+class DataFrameTimeWindowingSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach {
+
+ import testImplicits._
+
+ override def beforeEach(): Unit = {
+ super.beforeEach()
+ TimeZone.setDefault(TimeZone.getTimeZone("UTC"))
+ }
+
+ override def afterEach(): Unit = {
+    super.afterEach()
+ TimeZone.setDefault(null)
+ }
+
+ test("tumbling window groupBy statement") {
+ val df = Seq(
+ ("2016-03-27 19:39:34", 1, "a"),
+ ("2016-03-27 19:39:56", 2, "a"),
+ ("2016-03-27 19:39:27", 4, "b")).toDF("time", "value", "id")
+ checkAnswer(
+ df.groupBy(window($"time", "10 seconds"))
+ .agg(count("*").as("counts"))
+ .orderBy($"window.start".asc)
+ .select("counts"),
+ Seq(Row(1), Row(1), Row(1))
+ )
+ }
+
+ test("tumbling window groupBy statement with startTime") {
+ val df = Seq(
+ ("2016-03-27 19:39:34", 1, "a"),
+ ("2016-03-27 19:39:56", 2, "a"),
+ ("2016-03-27 19:39:27", 4, "b")).toDF("time", "value", "id")
+
+ checkAnswer(
+ df.groupBy(window($"time", "10 seconds", "10 seconds", "5 seconds"), $"id")
+ .agg(count("*").as("counts"))
+ .orderBy($"window.start".asc)
+ .select("counts"),
+ Seq(Row(1), Row(1), Row(1)))
+ }
+
+ test("tumbling window with multi-column projection") {
+ val df = Seq(
+ ("2016-03-27 19:39:34", 1, "a"),
+ ("2016-03-27 19:39:56", 2, "a"),
+ ("2016-03-27 19:39:27", 4, "b")).toDF("time", "value", "id")
+
+ checkAnswer(
+ df.select(window($"time", "10 seconds"), $"value")
+ .orderBy($"window.start".asc)
+ .select($"window.start".cast("string"), $"window.end".cast("string"), $"value"),
+ Seq(
+ Row("2016-03-27 19:39:20", "2016-03-27 19:39:30", 4),
+ Row("2016-03-27 19:39:30", "2016-03-27 19:39:40", 1),
+ Row("2016-03-27 19:39:50", "2016-03-27 19:40:00", 2)
+ )
+ )
+ }
+
+ test("sliding window grouping") {
+ val df = Seq(
+ ("2016-03-27 19:39:34", 1, "a"),
+ ("2016-03-27 19:39:56", 2, "a"),
+ ("2016-03-27 19:39:27", 4, "b")).toDF("time", "value", "id")
+
+ checkAnswer(
+ df.groupBy(window($"time", "10 seconds", "3 seconds", "0 second"))
+ .agg(count("*").as("counts"))
+ .orderBy($"window.start".asc)
+ .select($"window.start".cast("string"), $"window.end".cast("string"), $"counts"),
+ // 2016-03-27 19:39:27 UTC -> 4 bins
+ // 2016-03-27 19:39:34 UTC -> 3 bins
+ // 2016-03-27 19:39:56 UTC -> 3 bins
+ Seq(
+ Row("2016-03-27 19:39:18", "2016-03-27 19:39:28", 1),
+ Row("2016-03-27 19:39:21", "2016-03-27 19:39:31", 1),
+ Row("2016-03-27 19:39:24", "2016-03-27 19:39:34", 1),
+ Row("2016-03-27 19:39:27", "2016-03-27 19:39:37", 2),
+ Row("2016-03-27 19:39:30", "2016-03-27 19:39:40", 1),
+ Row("2016-03-27 19:39:33", "2016-03-27 19:39:43", 1),
+ Row("2016-03-27 19:39:48", "2016-03-27 19:39:58", 1),
+ Row("2016-03-27 19:39:51", "2016-03-27 19:40:01", 1),
+ Row("2016-03-27 19:39:54", "2016-03-27 19:40:04", 1))
+ )
+ }
+
+ test("sliding window projection") {
+ val df = Seq(
+ ("2016-03-27 19:39:34", 1, "a"),
+ ("2016-03-27 19:39:56", 2, "a"),
+ ("2016-03-27 19:39:27", 4, "b")).toDF("time", "value", "id")
+
+ checkAnswer(
+ df.select(window($"time", "10 seconds", "3 seconds", "0 second"), $"value")
+ .orderBy($"window.start".asc, $"value".desc).select("value"),
+ // 2016-03-27 19:39:27 UTC -> 4 bins
+ // 2016-03-27 19:39:34 UTC -> 3 bins
+ // 2016-03-27 19:39:56 UTC -> 3 bins
+ Seq(Row(4), Row(4), Row(4), Row(4), Row(1), Row(1), Row(1), Row(2), Row(2), Row(2))
+ )
+ }
+
+ test("windowing combined with explode expression") {
+ val df = Seq(
+ ("2016-03-27 19:39:34", 1, Seq("a", "b")),
+ ("2016-03-27 19:39:56", 2, Seq("a", "c", "d"))).toDF("time", "value", "ids")
+
+ checkAnswer(
+ df.select(window($"time", "10 seconds"), $"value", explode($"ids"))
+ .orderBy($"window.start".asc).select("value"),
+ // first window exploded to two rows for "a", and "b", second window exploded to 3 rows
+ Seq(Row(1), Row(1), Row(2), Row(2), Row(2))
+ )
+ }
+
+ test("null timestamps") {
+ val df = Seq(
+ ("2016-03-27 09:00:05", 1),
+ ("2016-03-27 09:00:32", 2),
+ (null, 3),
+ (null, 4)).toDF("time", "value")
+
+ checkDataset(
+ df.select(window($"time", "10 seconds"), $"value")
+ .orderBy($"window.start".asc)
+ .select("value")
+ .as[Int],
+ 1, 2) // null columns are dropped
+ }
+
+ test("time window joins") {
+ val df = Seq(
+ ("2016-03-27 09:00:05", 1),
+ ("2016-03-27 09:00:32", 2),
+ (null, 3),
+ (null, 4)).toDF("time", "value")
+
+ val df2 = Seq(
+ ("2016-03-27 09:00:02", 3),
+ ("2016-03-27 09:00:35", 6)).toDF("time", "othervalue")
+
+ checkAnswer(
+ df.select(window($"time", "10 seconds"), $"value").join(
+ df2.select(window($"time", "10 seconds"), $"othervalue"), Seq("window"))
+ .groupBy("window")
+ .agg((sum("value") + sum("othervalue")).as("total"))
+ .orderBy($"window.start".asc).select("total"),
+ Seq(Row(4), Row(8)))
+ }
+
+ test("negative timestamps") {
+ val df4 = Seq(
+ ("1970-01-01 00:00:02", 1),
+ ("1970-01-01 00:00:12", 2)).toDF("time", "value")
+ checkAnswer(
+ df4.select(window($"time", "10 seconds", "10 seconds", "5 seconds"), $"value")
+ .orderBy($"window.start".asc)
+ .select($"window.start".cast(StringType), $"window.end".cast(StringType), $"value"),
+ Seq(
+ Row("1969-12-31 23:59:55", "1970-01-01 00:00:05", 1),
+ Row("1970-01-01 00:00:05", "1970-01-01 00:00:15", 2))
+ )
+ }
+
+ test("multiple time windows in a single operator throws nice exception") {
+ val df = Seq(
+ ("2016-03-27 09:00:02", 3),
+ ("2016-03-27 09:00:35", 6)).toDF("time", "value")
+ val e = intercept[AnalysisException] {
+ df.select(window($"time", "10 second"), window($"time", "15 second")).collect()
+ }
+ assert(e.getMessage.contains(
+ "Multiple time window expressions would result in a cartesian product"))
+ }
+
+ test("aliased windows") {
+ val df = Seq(
+ ("2016-03-27 19:39:34", 1, Seq("a", "b")),
+ ("2016-03-27 19:39:56", 2, Seq("a", "c", "d"))).toDF("time", "value", "ids")
+
+ checkAnswer(
+ df.select(window($"time", "10 seconds").as("time_window"), $"value")
+ .orderBy($"time_window.start".asc)
+ .select("value"),
+ Seq(Row(1), Row(2))
+ )
+ }
+
+ test("millisecond precision sliding windows") {
+ val df = Seq(
+ ("2016-03-27 09:00:00.41", 3),
+ ("2016-03-27 09:00:00.62", 6),
+ ("2016-03-27 09:00:00.715", 8)).toDF("time", "value")
+ checkAnswer(
+ df.groupBy(window($"time", "200 milliseconds", "40 milliseconds", "0 milliseconds"))
+ .agg(count("*").as("counts"))
+ .orderBy($"window.start".asc)
+ .select($"window.start".cast(StringType), $"window.end".cast(StringType), $"counts"),
+ Seq(
+ Row("2016-03-27 09:00:00.24", "2016-03-27 09:00:00.44", 1),
+ Row("2016-03-27 09:00:00.28", "2016-03-27 09:00:00.48", 1),
+ Row("2016-03-27 09:00:00.32", "2016-03-27 09:00:00.52", 1),
+ Row("2016-03-27 09:00:00.36", "2016-03-27 09:00:00.56", 1),
+ Row("2016-03-27 09:00:00.4", "2016-03-27 09:00:00.6", 1),
+ Row("2016-03-27 09:00:00.44", "2016-03-27 09:00:00.64", 1),
+ Row("2016-03-27 09:00:00.48", "2016-03-27 09:00:00.68", 1),
+ Row("2016-03-27 09:00:00.52", "2016-03-27 09:00:00.72", 2),
+ Row("2016-03-27 09:00:00.56", "2016-03-27 09:00:00.76", 2),
+ Row("2016-03-27 09:00:00.6", "2016-03-27 09:00:00.8", 2),
+ Row("2016-03-27 09:00:00.64", "2016-03-27 09:00:00.84", 1),
+ Row("2016-03-27 09:00:00.68", "2016-03-27 09:00:00.88", 1))
+ )
+ }
+
+ private def withTempTable(f: String => Unit): Unit = {
+ val tableName = "temp"
+ Seq(
+ ("2016-03-27 19:39:34", 1),
+ ("2016-03-27 19:39:56", 2),
+ ("2016-03-27 19:39:27", 4)).toDF("time", "value").registerTempTable(tableName)
+ try {
+ f(tableName)
+ } finally {
+ sqlContext.dropTempTable(tableName)
+ }
+ }
+
+ test("time window in SQL with single string expression") {
+ withTempTable { table =>
+ checkAnswer(
+ sqlContext.sql(s"""select window(time, "10 seconds"), value from $table""")
+ .select($"window.start".cast(StringType), $"window.end".cast(StringType), $"value"),
+ Seq(
+ Row("2016-03-27 19:39:20", "2016-03-27 19:39:30", 4),
+ Row("2016-03-27 19:39:30", "2016-03-27 19:39:40", 1),
+ Row("2016-03-27 19:39:50", "2016-03-27 19:40:00", 2)
+ )
+ )
+ }
+ }
+
+ test("time window in SQL with with two expressions") {
+ withTempTable { table =>
+ checkAnswer(
+ sqlContext.sql(
+ s"""select window(time, "10 seconds", 10000000), value from $table""")
+ .select($"window.start".cast(StringType), $"window.end".cast(StringType), $"value"),
+ Seq(
+ Row("2016-03-27 19:39:20", "2016-03-27 19:39:30", 4),
+ Row("2016-03-27 19:39:30", "2016-03-27 19:39:40", 1),
+ Row("2016-03-27 19:39:50", "2016-03-27 19:40:00", 2)
+ )
+ )
+ }
+ }
+
+ test("time window in SQL with with three expressions") {
+ withTempTable { table =>
+ checkAnswer(
+ sqlContext.sql(
+ s"""select window(time, "10 seconds", 10000000, "5 seconds"), value from $table""")
+ .select($"window.start".cast(StringType), $"window.end".cast(StringType), $"value"),
+ Seq(
+ Row("2016-03-27 19:39:25", "2016-03-27 19:39:35", 1),
+ Row("2016-03-27 19:39:25", "2016-03-27 19:39:35", 4),
+ Row("2016-03-27 19:39:55", "2016-03-27 19:40:05", 2)
+ )
+ )
+ }
+ }
+}
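A minimal standalone sketch of the window() function that the suite above covers, assuming a local SQLContext and a UTC default time zone; the object name is illustrative. Tumbling windows put each timestamp into exactly one 10-second bucket.

    import java.util.TimeZone

    import org.apache.spark.SparkContext
    import org.apache.spark.sql.SQLContext
    import org.apache.spark.sql.functions._

    object TimeWindowSketch {
      def main(args: Array[String]): Unit = {
        TimeZone.setDefault(TimeZone.getTimeZone("UTC"))
        val sc = new SparkContext("local[*]", "time-window-sketch")
        val sqlContext = new SQLContext(sc)
        import sqlContext.implicits._

        val df = Seq(
          ("2016-03-27 19:39:34", 1),
          ("2016-03-27 19:39:56", 2),
          ("2016-03-27 19:39:27", 4)).toDF("time", "value")

        // One count per 10-second tumbling window, ordered by window start.
        df.groupBy(window($"time", "10 seconds"))
          .agg(count("*").as("counts"))
          .orderBy($"window.start".asc)
          .show()

        sc.stop()
      }
    }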
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
index 84770169f0..3a7215ee39 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
@@ -19,109 +19,83 @@ package org.apache.spark.sql
import scala.language.postfixOps
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.expressions.Aggregator
+import org.apache.spark.sql.expressions.scala.typed
import org.apache.spark.sql.functions._
import org.apache.spark.sql.test.SharedSQLContext
-/** An `Aggregator` that adds up any numeric type returned by the given function. */
-class SumOf[I, N : Numeric](f: I => N) extends Aggregator[I, N, N] {
- val numeric = implicitly[Numeric[N]]
-
- override def zero: N = numeric.zero
-
- override def reduce(b: N, a: I): N = numeric.plus(b, f(a))
-
- override def merge(b1: N, b2: N): N = numeric.plus(b1, b2)
-
- override def finish(reduction: N): N = reduction
-}
-
-object TypedAverage extends Aggregator[(String, Int), (Long, Long), Double] {
- override def zero: (Long, Long) = (0, 0)
-
- override def reduce(countAndSum: (Long, Long), input: (String, Int)): (Long, Long) = {
- (countAndSum._1 + 1, countAndSum._2 + input._2)
- }
-
- override def merge(b1: (Long, Long), b2: (Long, Long)): (Long, Long) = {
- (b1._1 + b2._1, b1._2 + b2._2)
- }
-
- override def finish(countAndSum: (Long, Long)): Double = countAndSum._2 / countAndSum._1
-}
object ComplexResultAgg extends Aggregator[(String, Int), (Long, Long), (Long, Long)] {
-
override def zero: (Long, Long) = (0, 0)
-
override def reduce(countAndSum: (Long, Long), input: (String, Int)): (Long, Long) = {
(countAndSum._1 + 1, countAndSum._2 + input._2)
}
-
override def merge(b1: (Long, Long), b2: (Long, Long)): (Long, Long) = {
(b1._1 + b2._1, b1._2 + b2._2)
}
-
override def finish(reduction: (Long, Long)): (Long, Long) = reduction
+ override def bufferEncoder: Encoder[(Long, Long)] = Encoders.product[(Long, Long)]
+ override def outputEncoder: Encoder[(Long, Long)] = Encoders.product[(Long, Long)]
}
+
case class AggData(a: Int, b: String)
+
object ClassInputAgg extends Aggregator[AggData, Int, Int] {
- /** A zero value for this aggregation. Should satisfy the property that any b + zero = b */
override def zero: Int = 0
-
- /**
- * Combine two values to produce a new value. For performance, the function may modify `b` and
- * return it instead of constructing new object for b.
- */
override def reduce(b: Int, a: AggData): Int = b + a.a
-
- /**
- * Transform the output of the reduction.
- */
override def finish(reduction: Int): Int = reduction
-
- /**
- * Merge two intermediate values
- */
override def merge(b1: Int, b2: Int): Int = b1 + b2
+ override def bufferEncoder: Encoder[Int] = Encoders.scalaInt
+ override def outputEncoder: Encoder[Int] = Encoders.scalaInt
}
+
object ComplexBufferAgg extends Aggregator[AggData, (Int, AggData), Int] {
- /** A zero value for this aggregation. Should satisfy the property that any b + zero = b */
override def zero: (Int, AggData) = 0 -> AggData(0, "0")
-
- /**
- * Combine two values to produce a new value. For performance, the function may modify `b` and
- * return it instead of constructing new object for b.
- */
override def reduce(b: (Int, AggData), a: AggData): (Int, AggData) = (b._1 + 1, a)
-
- /**
- * Transform the output of the reduction.
- */
override def finish(reduction: (Int, AggData)): Int = reduction._1
-
- /**
- * Merge two intermediate values
- */
override def merge(b1: (Int, AggData), b2: (Int, AggData)): (Int, AggData) =
(b1._1 + b2._1, b1._2)
+ override def bufferEncoder: Encoder[(Int, AggData)] = Encoders.product[(Int, AggData)]
+ override def outputEncoder: Encoder[Int] = Encoders.scalaInt
+}
+
+
+object NameAgg extends Aggregator[AggData, String, String] {
+ def zero: String = ""
+ def reduce(b: String, a: AggData): String = a.b + b
+ def merge(b1: String, b2: String): String = b1 + b2
+ def finish(r: String): String = r
+ override def bufferEncoder: Encoder[String] = Encoders.STRING
+ override def outputEncoder: Encoder[String] = Encoders.STRING
+}
+
+
+class ParameterizedTypeSum[IN, OUT : Numeric : Encoder](f: IN => OUT)
+ extends Aggregator[IN, OUT, OUT] {
+
+ private val numeric = implicitly[Numeric[OUT]]
+ override def zero: OUT = numeric.zero
+ override def reduce(b: OUT, a: IN): OUT = numeric.plus(b, f(a))
+ override def merge(b1: OUT, b2: OUT): OUT = numeric.plus(b1, b2)
+ override def finish(reduction: OUT): OUT = reduction
+ override def bufferEncoder: Encoder[OUT] = implicitly[Encoder[OUT]]
+ override def outputEncoder: Encoder[OUT] = implicitly[Encoder[OUT]]
}
+
class DatasetAggregatorSuite extends QueryTest with SharedSQLContext {
import testImplicits._
- def sum[I, N : Numeric : Encoder](f: I => N): TypedColumn[I, N] =
- new SumOf(f).toColumn
-
test("typed aggregation: TypedAggregator") {
val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
checkDataset(
- ds.groupByKey(_._1).agg(sum(_._2)),
- ("a", 30), ("b", 3), ("c", 1))
+ ds.groupByKey(_._1).agg(typed.sum(_._2)),
+ ("a", 30.0), ("b", 3.0), ("c", 1.0))
}
test("typed aggregation: TypedAggregator, expr, expr") {
@@ -129,20 +103,10 @@ class DatasetAggregatorSuite extends QueryTest with SharedSQLContext {
checkDataset(
ds.groupByKey(_._1).agg(
- sum(_._2),
+ typed.sum(_._2),
expr("sum(_2)").as[Long],
count("*")),
- ("a", 30, 30L, 2L), ("b", 3, 3L, 2L), ("c", 1, 1L, 1L))
- }
-
- test("typed aggregation: complex case") {
- val ds = Seq("a" -> 1, "a" -> 3, "b" -> 3).toDS()
-
- checkDataset(
- ds.groupByKey(_._1).agg(
- expr("avg(_2)").as[Double],
- TypedAverage.toColumn),
- ("a", 2.0, 2.0), ("b", 3.0, 3.0))
+ ("a", 30.0, 30L, 2L), ("b", 3.0, 3L, 2L), ("c", 1.0, 1L, 1L))
}
test("typed aggregation: complex result type") {
@@ -159,11 +123,11 @@ class DatasetAggregatorSuite extends QueryTest with SharedSQLContext {
val ds = Seq(1, 3, 2, 5).toDS()
checkDataset(
- ds.select(sum((i: Int) => i)),
- 11)
+ ds.select(typed.sum((i: Int) => i)),
+ 11.0)
checkDataset(
- ds.select(sum((i: Int) => i), sum((i: Int) => i * 2)),
- 11 -> 22)
+ ds.select(typed.sum((i: Int) => i), typed.sum((i: Int) => i * 2)),
+ 11.0 -> 22.0)
}
test("typed aggregation: class input") {
@@ -206,4 +170,34 @@ class DatasetAggregatorSuite extends QueryTest with SharedSQLContext {
ds.groupByKey(_.b).agg(ComplexBufferAgg.toColumn),
("one", 1), ("two", 1))
}
+
+ test("typed aggregate: avg, count, sum") {
+ val ds = Seq("a" -> 1, "a" -> 3, "b" -> 3).toDS()
+ checkDataset(
+ ds.groupByKey(_._1).agg(
+ typed.avg(_._2), typed.count(_._2), typed.sum(_._2), typed.sumLong(_._2)),
+ ("a", 2.0, 2L, 4.0, 4L), ("b", 3.0, 1L, 3.0, 3L))
+ }
+
+ test("generic typed sum") {
+ val ds = Seq("a" -> 1, "a" -> 3, "b" -> 3).toDS()
+ checkDataset(
+ ds.groupByKey(_._1)
+ .agg(new ParameterizedTypeSum[(String, Int), Double](_._2.toDouble).toColumn),
+ ("a", 4.0), ("b", 3.0))
+
+ checkDataset(
+ ds.groupByKey(_._1)
+ .agg(new ParameterizedTypeSum((x: (String, Int)) => x._2.toInt).toColumn),
+ ("a", 4), ("b", 3))
+ }
+
+ test("SPARK-12555 - result should not be corrupted after input columns are reordered") {
+ val ds = sql("SELECT 'Some String' AS b, 1279869254 AS a").as[AggData]
+
+ checkDataset(
+ ds.groupByKey(_.a).agg(NameAgg.toColumn),
+ (1279869254, "Some String"))
+ }
+
}
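The tests above switch from the local sum helper to the new typed.* aggregators. A minimal sketch of the four helpers, assuming a local SQLContext; note that typed.sum always produces a Double and typed.sumLong a Long.

    import org.apache.spark.SparkContext
    import org.apache.spark.sql.SQLContext
    import org.apache.spark.sql.expressions.scala.typed

    object TypedAggSketch {
      def main(args: Array[String]): Unit = {
        val sc = new SparkContext("local[*]", "typed-agg-sketch")
        val sqlContext = new SQLContext(sc)
        import sqlContext.implicits._

        val ds = Seq("a" -> 1, "a" -> 3, "b" -> 3).toDS()
        // Per key: average, count, sum as Double, and sum as Long of the value.
        val agged = ds.groupByKey(_._1).agg(
          typed.avg(_._2), typed.count(_._2), typed.sum(_._2), typed.sumLong(_._2))
        assert(agged.collect().toSet ==
          Set(("a", 2.0, 2L, 4.0, 4L), ("b", 3.0, 1L, 3.0, 3L)))
        sc.stop()
      }
    }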
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetBenchmark.scala
new file mode 100644
index 0000000000..5f3dd906fe
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetBenchmark.scala
@@ -0,0 +1,146 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.apache.spark.SparkContext
+import org.apache.spark.sql.types.StringType
+import org.apache.spark.util.Benchmark
+
+/**
+ * Benchmark for Dataset typed operations comparing with DataFrame and RDD versions.
+ */
+object DatasetBenchmark {
+
+ case class Data(l: Long, s: String)
+
+ def backToBackMap(sqlContext: SQLContext, numRows: Long, numChains: Int): Benchmark = {
+ import sqlContext.implicits._
+
+ val df = sqlContext.range(1, numRows).select($"id".as("l"), $"id".cast(StringType).as("s"))
+ val benchmark = new Benchmark("back-to-back map", numRows)
+
+ val func = (d: Data) => Data(d.l + 1, d.s)
+ benchmark.addCase("Dataset") { iter =>
+ var res = df.as[Data]
+ var i = 0
+ while (i < numChains) {
+ res = res.map(func)
+ i += 1
+ }
+ res.queryExecution.toRdd.foreach(_ => Unit)
+ }
+
+ benchmark.addCase("DataFrame") { iter =>
+ var res = df
+ var i = 0
+ while (i < numChains) {
+ res = res.select($"l" + 1 as "l")
+ i += 1
+ }
+ res.queryExecution.toRdd.foreach(_ => Unit)
+ }
+
+ val rdd = sqlContext.sparkContext.range(1, numRows).map(l => Data(l, l.toString))
+ benchmark.addCase("RDD") { iter =>
+ var res = rdd
+ var i = 0
+ while (i < numChains) {
+        res = res.map(func)
+ i += 1
+ }
+ res.foreach(_ => Unit)
+ }
+
+ benchmark
+ }
+
+ def backToBackFilter(sqlContext: SQLContext, numRows: Long, numChains: Int): Benchmark = {
+ import sqlContext.implicits._
+
+ val df = sqlContext.range(1, numRows).select($"id".as("l"), $"id".cast(StringType).as("s"))
+ val benchmark = new Benchmark("back-to-back filter", numRows)
+
+ val func = (d: Data, i: Int) => d.l % (100L + i) == 0L
+ val funcs = 0.until(numChains).map { i =>
+ (d: Data) => func(d, i)
+ }
+ benchmark.addCase("Dataset") { iter =>
+ var res = df.as[Data]
+ var i = 0
+ while (i < numChains) {
+ res = res.filter(funcs(i))
+ i += 1
+ }
+ res.queryExecution.toRdd.foreach(_ => Unit)
+ }
+
+ benchmark.addCase("DataFrame") { iter =>
+ var res = df
+ var i = 0
+ while (i < numChains) {
+ res = res.filter($"l" % (100L + i) === 0L)
+ i += 1
+ }
+ res.queryExecution.toRdd.foreach(_ => Unit)
+ }
+
+ val rdd = sqlContext.sparkContext.range(1, numRows).map(l => Data(l, l.toString))
+ benchmark.addCase("RDD") { iter =>
+ var res = rdd
+ var i = 0
+ while (i < numChains) {
+        res = res.filter(funcs(i))
+ i += 1
+ }
+ res.foreach(_ => Unit)
+ }
+
+ benchmark
+ }
+
+ def main(args: Array[String]): Unit = {
+ val sparkContext = new SparkContext("local[*]", "Dataset benchmark")
+ val sqlContext = new SQLContext(sparkContext)
+
+ val numRows = 10000000
+ val numChains = 10
+
+ val benchmark = backToBackMap(sqlContext, numRows, numChains)
+ val benchmark2 = backToBackFilter(sqlContext, numRows, numChains)
+
+ /*
+ Java HotSpot(TM) 64-Bit Server VM 1.8.0_60-b27 on Mac OS X 10.11.4
+ Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz
+ back-to-back map: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
+ -------------------------------------------------------------------------------------------
+ Dataset 902 / 995 11.1 90.2 1.0X
+ DataFrame 132 / 167 75.5 13.2 6.8X
+ RDD 216 / 237 46.3 21.6 4.2X
+ */
+ benchmark.run()
+
+ /*
+ back-to-back filter: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
+ -------------------------------------------------------------------------------------------
+ Dataset 585 / 628 17.1 58.5 1.0X
+ DataFrame 62 / 80 160.7 6.2 9.4X
+ RDD 205 / 220 48.7 20.5 2.8X
+ */
+ benchmark2.run()
+ }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index 553bc436a6..d074535bf6 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -23,6 +23,7 @@ import java.sql.{Date, Timestamp}
import scala.language.postfixOps
import org.apache.spark.sql.catalyst.encoders.OuterScopes
+import org.apache.spark.sql.execution.streaming.MemoryStream
import org.apache.spark.sql.functions._
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
@@ -71,6 +72,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
assert(ds.first() == item)
assert(ds.take(1).head == item)
assert(ds.takeAsList(1).get(0) == item)
+ assert(ds.toLocalIterator().next() === item)
}
test("coalesce, repartition") {
@@ -601,6 +603,29 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
TupleClass(1, "a")
)
}
+
+ test("isStreaming returns false for static Dataset") {
+ val data = Seq(("a", 1), ("b", 2), ("c", 3)).toDS()
+ assert(!data.isStreaming, "static Dataset returned true for 'isStreaming'.")
+ }
+
+ test("isStreaming returns true for streaming Dataset") {
+ val data = MemoryStream[Int].toDS()
+ assert(data.isStreaming, "streaming Dataset returned false for 'isStreaming'.")
+ }
+
+ test("isStreaming returns true after static and streaming Dataset join") {
+ val static = Seq(("a", 1), ("b", 2), ("c", 3)).toDF("a", "b")
+ val streaming = MemoryStream[Int].toDS().toDF("b")
+ val df = streaming.join(static, Seq("b"))
+ assert(df.isStreaming, "streaming Dataset returned false for 'isStreaming'.")
+ }
+
+ test("SPARK-14554: Dataset.map may generate wrong java code for wide table") {
+ val wideDF = sqlContext.range(10).select(Seq.tabulate(1000) {i => ('id + i).as(s"c$i")} : _*)
+ // Make sure the generated code for this plan can compile and execute.
+ checkDataset(wideDF.map(_.getLong(0)), 0L until 10 : _*)
+ }
}
case class OtherTuple(_1: String, _2: Int)
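The updated first/take test above now also checks Dataset.toLocalIterator. A minimal sketch, assuming a local SQLContext and that the iterator returns rows in partition order; names are illustrative.

    import org.apache.spark.SparkContext
    import org.apache.spark.sql.SQLContext

    object LocalIteratorSketch {
      def main(args: Array[String]): Unit = {
        val sc = new SparkContext("local[*]", "local-iterator-sketch")
        val sqlContext = new SQLContext(sc)
        import sqlContext.implicits._

        val ds = Seq(("a", 1), ("b", 2), ("c", 3)).toDS()
        // toLocalIterator pulls one partition at a time to the driver instead of
        // materializing the whole result like collect().
        val it = ds.toLocalIterator()
        assert(it.next() == ("a", 1))
        sc.stop()
      }
    }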
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
index 5af1a4fcd7..a87a41c126 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
@@ -41,7 +41,8 @@ class JoinSuite extends QueryTest with SharedSQLContext {
assert(planned.size === 1)
}
- def assertJoin(sqlString: String, c: Class[_]): Any = {
+ def assertJoin(pair: (String, Class[_])): Any = {
+ val (sqlString, c) = pair
val df = sql(sqlString)
val physical = df.queryExecution.sparkPlan
val operators = physical.collect {
@@ -53,8 +54,8 @@ class JoinSuite extends QueryTest with SharedSQLContext {
}
assert(operators.size === 1)
- if (operators(0).getClass() != c) {
- fail(s"$sqlString expected operator: $c, but got ${operators(0)}\n physical: \n$physical")
+ if (operators.head.getClass != c) {
+ fail(s"$sqlString expected operator: $c, but got ${operators.head}\n physical: \n$physical")
}
}
@@ -93,8 +94,10 @@ class JoinSuite extends QueryTest with SharedSQLContext {
("SELECT * FROM testData right JOIN testData2 ON (key * a != key + a)",
classOf[BroadcastNestedLoopJoin]),
("SELECT * FROM testData full JOIN testData2 ON (key * a != key + a)",
- classOf[BroadcastNestedLoopJoin])
- ).foreach { case (query, joinClass) => assertJoin(query, joinClass) }
+ classOf[BroadcastNestedLoopJoin]),
+ ("SELECT * FROM testData ANTI JOIN testData2 ON key = a", classOf[ShuffledHashJoin]),
+ ("SELECT * FROM testData LEFT ANTI JOIN testData2", classOf[BroadcastNestedLoopJoin])
+ ).foreach(assertJoin)
}
}
@@ -114,7 +117,7 @@ class JoinSuite extends QueryTest with SharedSQLContext {
classOf[BroadcastHashJoin]),
("SELECT * FROM testData join testData2 ON key = a where key = 2",
classOf[BroadcastHashJoin])
- ).foreach { case (query, joinClass) => assertJoin(query, joinClass) }
+ ).foreach(assertJoin)
sql("UNCACHE TABLE testData")
}
@@ -129,7 +132,7 @@ class JoinSuite extends QueryTest with SharedSQLContext {
classOf[BroadcastHashJoin]),
("SELECT * FROM testData right join testData2 ON key = a and key = 2",
classOf[BroadcastHashJoin])
- ).foreach { case (query, joinClass) => assertJoin(query, joinClass) }
+ ).foreach(assertJoin)
sql("UNCACHE TABLE testData")
}
@@ -329,8 +332,8 @@ class JoinSuite extends QueryTest with SharedSQLContext {
}
test("full outer join") {
- upperCaseData.where('N <= 4).registerTempTable("left")
- upperCaseData.where('N >= 3).registerTempTable("right")
+ upperCaseData.where('N <= 4).registerTempTable("`left`")
+ upperCaseData.where('N >= 3).registerTempTable("`right`")
val left = UnresolvedRelation(TableIdentifier("left"), None)
val right = UnresolvedRelation(TableIdentifier("right"), None)
@@ -419,25 +422,22 @@ class JoinSuite extends QueryTest with SharedSQLContext {
Row(null, 10))
}
- test("broadcasted left semi join operator selection") {
+ test("broadcasted existence join operator selection") {
sqlContext.cacheManager.clearCache()
sql("CACHE TABLE testData")
withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "1000000000") {
Seq(
- ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a",
- classOf[BroadcastHashJoin])
- ).foreach {
- case (query, joinClass) => assertJoin(query, joinClass)
- }
+ ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", classOf[BroadcastHashJoin]),
+ ("SELECT * FROM testData ANT JOIN testData2 ON key = a", classOf[BroadcastHashJoin])
+ ).foreach(assertJoin)
}
withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
Seq(
- ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", classOf[ShuffledHashJoin])
- ).foreach {
- case (query, joinClass) => assertJoin(query, joinClass)
- }
+ ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", classOf[ShuffledHashJoin]),
+ ("SELECT * FROM testData LEFT ANTI JOIN testData2 ON key = a", classOf[ShuffledHashJoin])
+ ).foreach(assertJoin)
}
sql("UNCACHE TABLE testData")
@@ -489,7 +489,7 @@ class JoinSuite extends QueryTest with SharedSQLContext {
classOf[BroadcastNestedLoopJoin]),
("SELECT * FROM testData full JOIN testData2 WHERE (key * a != key + a)",
classOf[BroadcastNestedLoopJoin])
- ).foreach { case (query, joinClass) => assertJoin(query, joinClass) }
+ ).foreach(assertJoin)
checkAnswer(
sql(
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ProcessingTimeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ProcessingTimeSuite.scala
new file mode 100644
index 0000000000..0d18a645f6
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ProcessingTimeSuite.scala
@@ -0,0 +1,40 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import java.util.concurrent.TimeUnit
+
+import scala.concurrent.duration._
+
+import org.apache.spark.SparkFunSuite
+
+class ProcessingTimeSuite extends SparkFunSuite {
+
+ test("create") {
+ assert(ProcessingTime(10.seconds).intervalMs === 10 * 1000)
+ assert(ProcessingTime.create(10, TimeUnit.SECONDS).intervalMs === 10 * 1000)
+ assert(ProcessingTime("1 minute").intervalMs === 60 * 1000)
+ assert(ProcessingTime("interval 1 minute").intervalMs === 60 * 1000)
+
+ intercept[IllegalArgumentException] { ProcessingTime(null: String) }
+ intercept[IllegalArgumentException] { ProcessingTime("") }
+ intercept[IllegalArgumentException] { ProcessingTime("invalid") }
+ intercept[IllegalArgumentException] { ProcessingTime("1 month") }
+ intercept[IllegalArgumentException] { ProcessingTime("1 year") }
+ }
+}
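A minimal sketch of the three ProcessingTime constructors the suite above exercises; the 10-second values are illustrative, and ProcessingTime is assumed to live in org.apache.spark.sql as it does in this test.

    import java.util.concurrent.TimeUnit

    import scala.concurrent.duration._

    import org.apache.spark.sql.ProcessingTime

    object ProcessingTimeSketch {
      def main(args: Array[String]): Unit = {
        // All three forms describe the same 10-second trigger interval.
        assert(ProcessingTime(10.seconds).intervalMs == 10000)
        assert(ProcessingTime.create(10, TimeUnit.SECONDS).intervalMs == 10000)
        assert(ProcessingTime("10 seconds").intervalMs == 10000)
      }
    }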
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
index a1b45ca7eb..826862835a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
@@ -28,9 +28,11 @@ import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.trees.TreeNode
import org.apache.spark.sql.catalyst.util._
-import org.apache.spark.sql.execution.{LogicalRDD, Queryable}
+import org.apache.spark.sql.execution.LogicalRDD
import org.apache.spark.sql.execution.columnar.InMemoryRelation
import org.apache.spark.sql.execution.datasources.LogicalRelation
+import org.apache.spark.sql.execution.streaming.MemoryPlan
+import org.apache.spark.sql.types.ObjectType
abstract class QueryTest extends PlanTest {
@@ -90,7 +92,7 @@ abstract class QueryTest extends PlanTest {
s"""
|Exception collecting dataset as objects
|${ds.resolvedTEncoder}
- |${ds.resolvedTEncoder.fromRowExpression.treeString}
+ |${ds.resolvedTEncoder.deserializer.treeString}
|${ds.queryExecution}
""".stripMargin, e)
}
@@ -105,11 +107,11 @@ abstract class QueryTest extends PlanTest {
val expected = expectedAnswer.toSet.toSeq.map((a: Any) => a.toString).sorted
val actual = decoded.toSet.toSeq.map((a: Any) => a.toString).sorted
- val comparision = sideBySide("expected" +: expected, "spark" +: actual).mkString("\n")
+ val comparison = sideBySide("expected" +: expected, "spark" +: actual).mkString("\n")
fail(
s"""Decoded objects do not match expected objects:
- |$comparision
- |${ds.resolvedTEncoder.fromRowExpression.treeString}
+ |$comparison
+ |${ds.resolvedTEncoder.deserializer.treeString}
""".stripMargin)
}
}
@@ -180,9 +182,9 @@ abstract class QueryTest extends PlanTest {
}
/**
- * Asserts that a given [[Queryable]] will be executed using the given number of cached results.
+ * Asserts that a given [[Dataset]] will be executed using the given number of cached results.
*/
- def assertCached(query: Queryable, numCachedTables: Int = 1): Unit = {
+ def assertCached(query: Dataset[_], numCachedTables: Int = 1): Unit = {
val planWithCaching = query.queryExecution.withCachedData
val cachedData = planWithCaching collect {
case cached: InMemoryRelation => cached
@@ -198,13 +200,12 @@ abstract class QueryTest extends PlanTest {
val logicalPlan = df.queryExecution.analyzed
// bypass some cases that we can't handle currently.
logicalPlan.transform {
- case _: MapPartitions => return
- case _: MapGroups => return
- case _: AppendColumns => return
- case _: CoGroup => return
+ case _: ObjectOperator => return
case _: LogicalRelation => return
+ case _: MemoryPlan => return
}.transformAllExpressions {
case a: ImperativeAggregate => return
+ case Literal(_, _: ObjectType) => return
}
// bypass hive tests before we fix all corner cases in hive module.
@@ -286,9 +287,9 @@ abstract class QueryTest extends PlanTest {
}
/**
- * Asserts that a given [[Queryable]] does not have missing inputs in all the analyzed plans.
- */
- def assertEmptyMissingInput(query: Queryable): Unit = {
+ * Asserts that a given [[Dataset]] does not have missing inputs in all the analyzed plans.
+ */
+ def assertEmptyMissingInput(query: Dataset[_]): Unit = {
assert(query.queryExecution.analyzed.missingInput.isEmpty,
s"The analyzed logical plan has missing inputs: ${query.queryExecution.analyzed}")
assert(query.queryExecution.optimizedPlan.missingInput.isEmpty,
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index c958eac266..cdd404d699 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -24,6 +24,7 @@ import org.apache.spark.AccumulatorSuite
import org.apache.spark.sql.catalyst.analysis.UnresolvedException
import org.apache.spark.sql.catalyst.expressions.SortOrder
import org.apache.spark.sql.catalyst.plans.logical.Aggregate
+import org.apache.spark.sql.catalyst.util.StringUtils
import org.apache.spark.sql.execution.aggregate
import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, CartesianProduct, SortMergeJoin}
import org.apache.spark.sql.functions._
@@ -56,12 +57,13 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
test("show functions") {
def getFunctions(pattern: String): Seq[Row] = {
- val regex = java.util.regex.Pattern.compile(pattern)
- sqlContext.sessionState.functionRegistry.listFunction()
- .filter(regex.matcher(_).matches()).map(Row(_))
+ StringUtils.filterPattern(sqlContext.sessionState.functionRegistry.listFunction(), pattern)
+ .map(Row(_))
}
- checkAnswer(sql("SHOW functions"), getFunctions(".*"))
- Seq("^c.*", ".*e$", "log.*", ".*date.*").foreach { pattern =>
+ checkAnswer(sql("SHOW functions"), getFunctions("*"))
+ Seq("^c*", "*e$", "log*", "*date*").foreach { pattern =>
+ // For the pattern part, only '*' and '|' are allowed as wildcards.
+ // For '*', we need to replace it to '.*'.
checkAnswer(sql(s"SHOW FUNCTIONS '$pattern'"), getFunctions(pattern))
}
}
@@ -87,6 +89,14 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
"Function: abcadf not found.")
}
+ test("SPARK-14415: All functions should have own descriptions") {
+ for (f <- sqlContext.sessionState.functionRegistry.listFunction()) {
+ if (!Seq("cube", "grouping", "grouping_id", "rollup", "window").contains(f)) {
+ checkExistence(sql(s"describe function `$f`"), false, "To be added.")
+ }
+ }
+ }
+
test("SPARK-6743: no columns from cache") {
Seq(
(83, 0, 38),
@@ -1656,7 +1666,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
val e2 = intercept[AnalysisException] {
sql("select interval 23 nanosecond")
}
- assert(e2.message.contains("cannot recognize input near"))
+ assert(e2.message.contains("No interval can be constructed"))
}
test("SPARK-8945: add and subtract expressions for interval type") {
@@ -1817,12 +1827,12 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
val e1 = intercept[AnalysisException] {
sql("select * from in_valid_table")
}
- assert(e1.message.contains("Table not found"))
+ assert(e1.message.contains("Table or View not found"))
val e2 = intercept[AnalysisException] {
sql("select * from no_db.no_table").show()
}
- assert(e2.message.contains("Table not found"))
+ assert(e2.message.contains("Table or View not found"))
val e3 = intercept[AnalysisException] {
sql("select * from json.invalid_file")
@@ -2228,6 +2238,88 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
assert(error.getMessage contains "grouping__id is deprecated; use grouping_id() instead")
}
+ test("grouping and grouping_id in having") {
+ checkAnswer(
+ sql("select course, year from courseSales group by cube(course, year)" +
+ " having grouping(year) = 1 and grouping_id(course, year) > 0"),
+ Row("Java", null) ::
+ Row("dotNET", null) ::
+ Row(null, null) :: Nil
+ )
+
+ var error = intercept[AnalysisException] {
+ sql("select course, year from courseSales group by course, year" +
+ " having grouping(course) > 0")
+ }
+ assert(error.getMessage contains
+ "grouping()/grouping_id() can only be used with GroupingSets/Cube/Rollup")
+ error = intercept[AnalysisException] {
+ sql("select course, year from courseSales group by course, year" +
+ " having grouping_id(course, year) > 0")
+ }
+ assert(error.getMessage contains
+ "grouping()/grouping_id() can only be used with GroupingSets/Cube/Rollup")
+ error = intercept[AnalysisException] {
+ sql("select course, year from courseSales group by cube(course, year)" +
+ " having grouping__id > 0")
+ }
+ assert(error.getMessage contains "grouping__id is deprecated; use grouping_id() instead")
+ }
+
+ test("grouping and grouping_id in sort") {
+ checkAnswer(
+ sql("select course, year, grouping(course), grouping(year) from courseSales" +
+ " group by cube(course, year) order by grouping_id(course, year), course, year"),
+ Row("Java", 2012, 0, 0) ::
+ Row("Java", 2013, 0, 0) ::
+ Row("dotNET", 2012, 0, 0) ::
+ Row("dotNET", 2013, 0, 0) ::
+ Row("Java", null, 0, 1) ::
+ Row("dotNET", null, 0, 1) ::
+ Row(null, 2012, 1, 0) ::
+ Row(null, 2013, 1, 0) ::
+ Row(null, null, 1, 1) :: Nil
+ )
+
+ checkAnswer(
+ sql("select course, year, grouping_id(course, year) from courseSales" +
+ " group by cube(course, year) order by grouping(course), grouping(year), course, year"),
+ Row("Java", 2012, 0) ::
+ Row("Java", 2013, 0) ::
+ Row("dotNET", 2012, 0) ::
+ Row("dotNET", 2013, 0) ::
+ Row("Java", null, 1) ::
+ Row("dotNET", null, 1) ::
+ Row(null, 2012, 2) ::
+ Row(null, 2013, 2) ::
+ Row(null, null, 3) :: Nil
+ )
+
+ var error = intercept[AnalysisException] {
+ sql("select course, year from courseSales group by course, year" +
+ " order by grouping(course)")
+ }
+ assert(error.getMessage contains
+ "grouping()/grouping_id() can only be used with GroupingSets/Cube/Rollup")
+ error = intercept[AnalysisException] {
+ sql("select course, year from courseSales group by course, year" +
+ " order by grouping_id(course, year)")
+ }
+ assert(error.getMessage contains
+ "grouping()/grouping_id() can only be used with GroupingSets/Cube/Rollup")
+ error = intercept[AnalysisException] {
+ sql("select course, year from courseSales group by cube(course, year)" +
+ " order by grouping__id")
+ }
+ assert(error.getMessage contains "grouping__id is deprecated; use grouping_id() instead")
+ }
+
+ test("filter on a grouping column that is not presented in SELECT") {
+ checkAnswer(
+ sql("select count(1) from (select 1 as a) t group by a having a > 0"),
+ Row(1) :: Nil)
+ }
+
test("SPARK-13056: Null in map value causes NPE") {
val df = Seq(1 -> Map("abc" -> "somestring", "cba" -> null)).toDF("key", "value")
withTempTable("maptest") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala
index 4ca739450c..6ccc99fe17 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala
@@ -36,6 +36,7 @@ import org.scalatest.time.SpanSugar._
import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder, RowEncoder}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.util._
+import org.apache.spark.sql.execution.datasources.DataSource
import org.apache.spark.sql.execution.streaming._
import org.apache.spark.util.Utils
@@ -66,9 +67,9 @@ import org.apache.spark.util.Utils
trait StreamTest extends QueryTest with Timeouts {
implicit class RichSource(s: Source) {
- def toDF(): DataFrame = Dataset.ofRows(sqlContext, StreamingRelation(s))
+ def toDF(): DataFrame = Dataset.ofRows(sqlContext, StreamingExecutionRelation(s))
- def toDS[A: Encoder](): Dataset[A] = Dataset(sqlContext, StreamingRelation(s))
+ def toDS[A: Encoder](): Dataset[A] = Dataset(sqlContext, StreamingExecutionRelation(s))
}
/** How long to wait for an active stream to catch up when checking a result. */
@@ -116,15 +117,30 @@ trait StreamTest extends QueryTest with Timeouts {
def apply[A : Encoder](data: A*): CheckAnswerRows = {
val encoder = encoderFor[A]
val toExternalRow = RowEncoder(encoder.schema)
- CheckAnswerRows(data.map(d => toExternalRow.fromRow(encoder.toRow(d))))
+ CheckAnswerRows(data.map(d => toExternalRow.fromRow(encoder.toRow(d))), false)
}
- def apply(rows: Row*): CheckAnswerRows = CheckAnswerRows(rows)
+ def apply(rows: Row*): CheckAnswerRows = CheckAnswerRows(rows, false)
}
- case class CheckAnswerRows(expectedAnswer: Seq[Row])
+ /**
+ * Checks to make sure that the current data stored in the sink matches the `expectedAnswer`.
+ * This operation automatically blocks until all added data has been processed.
+ */
+ object CheckLastBatch {
+ def apply[A : Encoder](data: A*): CheckAnswerRows = {
+ val encoder = encoderFor[A]
+ val toExternalRow = RowEncoder(encoder.schema)
+ CheckAnswerRows(data.map(d => toExternalRow.fromRow(encoder.toRow(d))), true)
+ }
+
+ def apply(rows: Row*): CheckAnswerRows = CheckAnswerRows(rows, true)
+ }
+
+ case class CheckAnswerRows(expectedAnswer: Seq[Row], lastOnly: Boolean)
extends StreamAction with StreamMustBeRunning {
- override def toString: String = s"CheckAnswer: ${expectedAnswer.mkString(",")}"
+ override def toString: String = s"$operatorName: ${expectedAnswer.mkString(",")}"
+ private def operatorName = if (lastOnly) "CheckLastBatch" else "CheckAnswer"
}
/** Stops the stream. It must currently be running. */
@@ -224,11 +240,8 @@ trait StreamTest extends QueryTest with Timeouts {
""".stripMargin
def verify(condition: => Boolean, message: String): Unit = {
- try {
- Assertions.assert(condition)
- } catch {
- case NonFatal(e) =>
- failTest(message, e)
+ if (!condition) {
+ failTest(message)
}
}
@@ -265,7 +278,7 @@ trait StreamTest extends QueryTest with Timeouts {
}
val testThread = Thread.currentThread()
- val metadataRoot = Utils.createTempDir("streaming.metadata").getCanonicalPath
+ val metadataRoot = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
try {
startedTest.foreach { action =>
@@ -276,7 +289,11 @@ trait StreamTest extends QueryTest with Timeouts {
currentStream =
sqlContext
.streams
- .startQuery(StreamExecution.nextName, metadataRoot, stream, sink)
+ .startQuery(
+ StreamExecution.nextName,
+ metadataRoot,
+ stream,
+ sink)
.asInstanceOf[StreamExecution]
currentStream.microBatchThread.setUncaughtExceptionHandler(
new UncaughtExceptionHandler {
@@ -351,7 +368,7 @@ trait StreamTest extends QueryTest with Timeouts {
case a: AddData =>
awaiting.put(a.source, a.addData())
- case CheckAnswerRows(expectedAnswer) =>
+ case CheckAnswerRows(expectedAnswer, lastOnly) =>
verify(currentStream != null, "stream not running")
// Block until all data added has been processed
@@ -361,12 +378,12 @@ trait StreamTest extends QueryTest with Timeouts {
}
}
- val allData = try sink.allData catch {
+ val sparkAnswer = try if (lastOnly) sink.lastBatch else sink.allData catch {
case e: Exception =>
failTest("Exception while getting data from sink", e)
}
- QueryTest.sameRows(expectedAnswer, allData).foreach {
+ QueryTest.sameRows(expectedAnswer, sparkAnswer).foreach {
error => failTest(error)
}
}
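The new CheckLastBatch action mirrors CheckAnswer but compares only sink.lastBatch rather than sink.allData. A minimal sketch of how a suite built on this trait might use the two actions, assuming the MemoryStream source and the testStream/AddData helpers that StreamTest already provides (and the usual test implicits in scope):

    val input = MemoryStream[Int]
    val mapped = input.toDS().map(_ + 1)

    testStream(mapped)(
      AddData(input, 1, 2, 3),
      CheckAnswer(2, 3, 4),     // compares against everything the sink has seen so far
      AddData(input, 4, 5),
      CheckLastBatch(5, 6)      // compares against only the most recently processed batch
    )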
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
index e2090b0a83..6809f26968 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
@@ -272,12 +272,12 @@ class StringFunctionsSuite extends QueryTest with SharedSQLContext {
}
test("initcap function") {
- val df = Seq(("ab", "a B")).toDF("l", "r")
+ val df = Seq(("ab", "a B", "sParK")).toDF("x", "y", "z")
checkAnswer(
- df.select(initcap($"l"), initcap($"r")), Row("Ab", "A B"))
+ df.select(initcap($"x"), initcap($"y"), initcap($"z")), Row("Ab", "A B", "Spark"))
checkAnswer(
- df.selectExpr("InitCap(l)", "InitCap(r)"), Row("Ab", "A B"))
+ df.selectExpr("InitCap(x)", "InitCap(y)", "InitCap(z)"), Row("Ab", "A B", "Spark"))
}
test("number format function") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
index fd736718af..ec950332c5 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
@@ -83,7 +83,8 @@ class UDFSuite extends QueryTest with SharedSQLContext {
val e = intercept[AnalysisException] {
df.selectExpr("a_function_that_does_not_exist()")
}
- assert(e.getMessage.contains("undefined function"))
+ assert(e.getMessage.contains("Undefined function"))
+ assert(e.getMessage.contains("a_function_that_does_not_exist"))
}
test("Simple UDF") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
index 0b1cb90186..352fd07d0e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
@@ -21,20 +21,22 @@ import java.util.HashMap
import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite}
import org.apache.spark.memory.{StaticMemoryManager, TaskMemoryManager}
+import org.apache.spark.sql.execution.joins.LongToUnsafeRowMap
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
+import org.apache.spark.sql.execution.vectorized.AggregateHashMap
import org.apache.spark.sql.functions._
-import org.apache.spark.sql.types.IntegerType
+import org.apache.spark.sql.types.{IntegerType, LongType, StructType}
import org.apache.spark.unsafe.Platform
import org.apache.spark.unsafe.hash.Murmur3_x86_32
import org.apache.spark.unsafe.map.BytesToBytesMap
import org.apache.spark.util.Benchmark
/**
- * Benchmark to measure whole stage codegen performance.
- * To run this:
- * build/sbt "sql/test-only *BenchmarkWholeStageCodegen"
- */
+ * Benchmark to measure whole stage codegen performance.
+ * To run this:
+ * build/sbt "sql/test-only *BenchmarkWholeStageCodegen"
+ */
class BenchmarkWholeStageCodegen extends SparkFunSuite {
lazy val conf = new SparkConf().setMaster("local[1]").setAppName("benchmark")
.set("spark.sql.shuffle.partitions", "1")
@@ -84,6 +86,31 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite {
*/
}
+ ignore("range/sample/sum") {
+ val N = 500 << 20
+ runBenchmark("range/sample/sum", N) {
+ sqlContext.range(N).sample(true, 0.01).groupBy().sum().collect()
+ }
+ /*
+ Westmere E56xx/L56xx/X56xx (Nehalem-C)
+ range/sample/sum: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
+ -------------------------------------------------------------------------------------------
+ range/sample/sum codegen=false 53888 / 56592 9.7 102.8 1.0X
+ range/sample/sum codegen=true 41614 / 42607 12.6 79.4 1.3X
+ */
+
+ runBenchmark("range/sample/sum", N) {
+ sqlContext.range(N).sample(false, 0.01).groupBy().sum().collect()
+ }
+ /*
+ Westmere E56xx/L56xx/X56xx (Nehalem-C)
+ range/sample/sum: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
+ -------------------------------------------------------------------------------------------
+ range/sample/sum codegen=false 12982 / 13384 40.4 24.8 1.0X
+ range/sample/sum codegen=true 7074 / 7383 74.1 13.5 1.8X
+ */
+ }
+
ignore("stat functions") {
val N = 100L << 20
@@ -140,20 +167,35 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite {
}
ignore("broadcast hash join") {
- val N = 100 << 20
+ val N = 20 << 20
val M = 1 << 16
val dim = broadcast(sqlContext.range(M).selectExpr("id as k", "cast(id as string) as v"))
runBenchmark("Join w long", N) {
- sqlContext.range(N).join(dim, (col("id") bitwiseAND M) === col("k")).count()
+ sqlContext.range(N).join(dim, (col("id") % M) === col("k")).count()
}
/*
+ Java HotSpot(TM) 64-Bit Server VM 1.7.0_60-b19 on Mac OS X 10.9.5
Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz
Join w long: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
-------------------------------------------------------------------------------------------
- Join w long codegen=false 5744 / 5814 18.3 54.8 1.0X
- Join w long codegen=true 735 / 853 142.7 7.0 7.8X
+ Join w long codegen=false 3002 / 3262 7.0 143.2 1.0X
+ Join w long codegen=true 321 / 371 65.3 15.3 9.3X
+ */
+
+ runBenchmark("Join w long duplicated", N) {
+ val dim = broadcast(sqlContext.range(M).selectExpr("cast(id/10 as long) as k"))
+ sqlContext.range(N).join(dim, (col("id") % M) === col("k")).count()
+ }
+
+ /**
+ Java HotSpot(TM) 64-Bit Server VM 1.7.0_60-b19 on Mac OS X 10.9.5
+ Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz
+ Join w long duplicated: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
+ -------------------------------------------------------------------------------------------
+ Join w long duplicated codegen=false 3446 / 3478 6.1 164.3 1.0X
+ Join w long duplicated codegen=true 322 / 351 65.2 15.3 10.7X
*/
val dim2 = broadcast(sqlContext.range(M)
@@ -161,16 +203,17 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite {
runBenchmark("Join w 2 ints", N) {
sqlContext.range(N).join(dim2,
- (col("id") bitwiseAND M).cast(IntegerType) === col("k1")
- && (col("id") bitwiseAND M).cast(IntegerType) === col("k2")).count()
+ (col("id") % M).cast(IntegerType) === col("k1")
+ && (col("id") % M).cast(IntegerType) === col("k2")).count()
}
/**
+ Java HotSpot(TM) 64-Bit Server VM 1.7.0_60-b19 on Mac OS X 10.9.5
Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz
Join w 2 ints: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
-------------------------------------------------------------------------------------------
- Join w 2 ints codegen=false 7159 / 7224 14.6 68.3 1.0X
- Join w 2 ints codegen=true 1135 / 1197 92.4 10.8 6.3X
+ Join w 2 ints codegen=false 4426 / 4501 4.7 211.1 1.0X
+ Join w 2 ints codegen=true 791 / 818 26.5 37.7 5.6X
*/
val dim3 = broadcast(sqlContext.range(M)
@@ -178,39 +221,60 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite {
runBenchmark("Join w 2 longs", N) {
sqlContext.range(N).join(dim3,
- (col("id") bitwiseAND M) === col("k1") && (col("id") bitwiseAND M) === col("k2"))
+ (col("id") % M) === col("k1") && (col("id") % M) === col("k2"))
.count()
}
/**
+ Java HotSpot(TM) 64-Bit Server VM 1.7.0_60-b19 on Mac OS X 10.9.5
Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz
- Join w 2 longs: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
+ Join w 2 longs: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
-------------------------------------------------------------------------------------------
- Join w 2 longs codegen=false 7877 / 8358 13.3 75.1 1.0X
- Join w 2 longs codegen=true 3877 / 3937 27.0 37.0 2.0X
+ Join w 2 longs codegen=false 5905 / 6123 3.6 281.6 1.0X
+ Join w 2 longs codegen=true 2230 / 2529 9.4 106.3 2.6X
*/
+
+ val dim4 = broadcast(sqlContext.range(M)
+ .selectExpr("cast(id/10 as long) as k1", "cast(id/10 as long) as k2"))
+
+ runBenchmark("Join w 2 longs duplicated", N) {
+ sqlContext.range(N).join(dim4,
+ (col("id") bitwiseAND M) === col("k1") && (col("id") bitwiseAND M) === col("k2"))
+ .count()
+ }
+
+ /**
+ Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz
+ Join w 2 longs duplicated: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
+ -------------------------------------------------------------------------------------------
+ Join w 2 longs duplicated codegen=false 6420 / 6587 3.3 306.1 1.0X
+ Join w 2 longs duplicated codegen=true 2080 / 2139 10.1 99.2 3.1X
+ */
+
runBenchmark("outer join w long", N) {
- sqlContext.range(N).join(dim, (col("id") bitwiseAND M) === col("k"), "left").count()
+ sqlContext.range(N).join(dim, (col("id") % M) === col("k"), "left").count()
}
/**
+ Java HotSpot(TM) 64-Bit Server VM 1.7.0_60-b19 on Mac OS X 10.9.5
Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz
outer join w long: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
-------------------------------------------------------------------------------------------
- outer join w long codegen=false 15280 / 16497 6.9 145.7 1.0X
- outer join w long codegen=true 769 / 796 136.3 7.3 19.9X
+ outer join w long codegen=false 3055 / 3189 6.9 145.7 1.0X
+ outer join w long codegen=true 261 / 276 80.5 12.4 11.7X
*/
runBenchmark("semi join w long", N) {
- sqlContext.range(N).join(dim, (col("id") bitwiseAND M) === col("k"), "leftsemi").count()
+ sqlContext.range(N).join(dim, (col("id") % M) === col("k"), "leftsemi").count()
}
/**
+ Java HotSpot(TM) 64-Bit Server VM 1.7.0_60-b19 on Mac OS X 10.9.5
Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz
semi join w long: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
-------------------------------------------------------------------------------------------
- semi join w long codegen=false 5804 / 5969 18.1 55.3 1.0X
- semi join w long codegen=true 814 / 934 128.8 7.8 7.1X
+ semi join w long codegen=false 1912 / 1990 11.0 91.2 1.0X
+ semi join w long codegen=true 237 / 244 88.3 11.3 8.1X
*/
}
@@ -259,11 +323,12 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite {
}
/**
+ Java HotSpot(TM) 64-Bit Server VM 1.7.0_60-b19 on Mac OS X 10.9.5
Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz
shuffle hash join: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
-------------------------------------------------------------------------------------------
- shuffle hash join codegen=false 1168 / 1902 3.6 278.6 1.0X
- shuffle hash join codegen=true 850 / 1196 4.9 202.8 1.4X
+ shuffle hash join codegen=false 1101 / 1391 3.8 262.6 1.0X
+ shuffle hash join codegen=true 528 / 578 7.9 125.8 2.1X
*/
}
@@ -285,11 +350,11 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite {
}
ignore("hash and BytesToBytesMap") {
- val N = 10 << 20
+ val N = 20 << 20
val benchmark = new Benchmark("BytesToBytesMap", N)
- benchmark.addCase("hash") { iter =>
+ benchmark.addCase("UnsafeRowhash") { iter =>
var i = 0
val keyBytes = new Array[Byte](16)
val key = new UnsafeRow(1)
@@ -304,15 +369,34 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite {
}
}
+ benchmark.addCase("murmur3 hash") { iter =>
+ var i = 0
+ val keyBytes = new Array[Byte](16)
+ val key = new UnsafeRow(1)
+ key.pointTo(keyBytes, Platform.BYTE_ARRAY_OFFSET, 16)
+ var p = 524283
+ var s = 0
+ while (i < N) {
+ var h = Murmur3_x86_32.hashLong(i, 42)
+ key.setInt(0, h)
+ s += h
+ i += 1
+ }
+ }
+
benchmark.addCase("fast hash") { iter =>
var i = 0
val keyBytes = new Array[Byte](16)
val key = new UnsafeRow(1)
key.pointTo(keyBytes, Platform.BYTE_ARRAY_OFFSET, 16)
+ var p = 524283
var s = 0
while (i < N) {
- key.setInt(0, i % 1000)
- val h = Murmur3_x86_32.hashLong(i % 1000, 42)
+ var h = i % p
+ if (h < 0) {
+ h += p
+ }
+ key.setInt(0, h)
s += h
i += 1
}
@@ -411,6 +495,42 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite {
}
}
+ Seq(false, true).foreach { optimized =>
+ benchmark.addCase(s"LongToUnsafeRowMap (opt=$optimized)") { iter =>
+ var i = 0
+ val valueBytes = new Array[Byte](16)
+ val value = new UnsafeRow(1)
+ value.pointTo(valueBytes, Platform.BYTE_ARRAY_OFFSET, 16)
+ value.setInt(0, 555)
+ val taskMemoryManager = new TaskMemoryManager(
+ new StaticMemoryManager(
+ new SparkConf().set("spark.memory.offHeap.enabled", "false"),
+ Long.MaxValue,
+ Long.MaxValue,
+ 1),
+ 0)
+ val map = new LongToUnsafeRowMap(taskMemoryManager, 64)
+ while (i < 65536) {
+ value.setInt(0, i)
+ val key = i % 100000
+ map.append(key, value)
+ i += 1
+ }
+ if (optimized) {
+ map.optimize()
+ }
+ var s = 0
+ i = 0
+ while (i < N) {
+ val key = i % 100000
+ if (map.getValue(key, value) != null) {
+ s += 1
+ }
+ i += 1
+ }
+ }
+ }
+
Seq("off", "on").foreach { heap =>
benchmark.addCase(s"BytesToBytesMap ($heap Heap)") { iter =>
val taskMemoryManager = new TaskMemoryManager(
@@ -429,35 +549,70 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite {
val value = new UnsafeRow(1)
value.pointTo(valueBytes, Platform.BYTE_ARRAY_OFFSET, 16)
var i = 0
- while (i < N) {
+ val numKeys = 65536
+ while (i < numKeys) {
key.setInt(0, i % 65536)
val loc = map.lookup(key.getBaseObject, key.getBaseOffset, key.getSizeInBytes,
Murmur3_x86_32.hashLong(i % 65536, 42))
- if (loc.isDefined) {
- value.pointTo(loc.getValueBase, loc.getValueOffset, loc.getValueLength)
- value.setInt(0, value.getInt(0) + 1)
- i += 1
- } else {
- loc.putNewKey(key.getBaseObject, key.getBaseOffset, key.getSizeInBytes,
+ if (!loc.isDefined) {
+ loc.append(key.getBaseObject, key.getBaseOffset, key.getSizeInBytes,
value.getBaseObject, value.getBaseOffset, value.getSizeInBytes)
}
+ i += 1
+ }
+ i = 0
+ var s = 0
+ while (i < N) {
+ key.setInt(0, i % 100000)
+ val loc = map.lookup(key.getBaseObject, key.getBaseOffset, key.getSizeInBytes,
+ Murmur3_x86_32.hashLong(i % 100000, 42))
+ if (loc.isDefined) {
+ s += 1
+ }
+ i += 1
+ }
+ }
+ }
+
+ benchmark.addCase("Aggregate HashMap") { iter =>
+ var i = 0
+ val numKeys = 65536
+ val schema = new StructType()
+ .add("key", LongType)
+ .add("value", LongType)
+ val map = new AggregateHashMap(schema)
+ while (i < numKeys) {
+ val row = map.findOrInsert(i.toLong)
+ row.setLong(1, row.getLong(1) + 1)
+ i += 1
+ }
+ var s = 0
+ i = 0
+ while (i < N) {
+ if (map.find(i % 100000) != -1) {
+ s += 1
}
+ i += 1
}
}
/**
- Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz
+ Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz
BytesToBytesMap: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
-------------------------------------------------------------------------------------------
- hash 651 / 678 80.0 12.5 1.0X
- fast hash 336 / 343 155.9 6.4 1.9X
- arrayEqual 417 / 428 125.0 8.0 1.6X
- Java HashMap (Long) 145 / 168 72.2 13.8 0.8X
- Java HashMap (two ints) 157 / 164 66.8 15.0 0.8X
- Java HashMap (UnsafeRow) 538 / 573 19.5 51.3 0.2X
- BytesToBytesMap (off Heap) 2594 / 2664 20.2 49.5 0.2X
- BytesToBytesMap (on Heap) 2693 / 2989 19.5 51.4 0.2X
- */
+ UnsafeRow hash 267 / 284 78.4 12.8 1.0X
+ murmur3 hash 102 / 129 205.5 4.9 2.6X
+ fast hash 79 / 96 263.8 3.8 3.4X
+ arrayEqual 164 / 172 128.2 7.8 1.6X
+ Java HashMap (Long) 321 / 399 65.4 15.3 0.8X
+ Java HashMap (two ints) 328 / 363 63.9 15.7 0.8X
+ Java HashMap (UnsafeRow) 1140 / 1200 18.4 54.3 0.2X
+ LongToUnsafeRowMap (opt=false) 378 / 400 55.5 18.0 0.7X
+ LongToUnsafeRowMap (opt=true) 144 / 152 145.2 6.9 1.9X
+ BytesToBytesMap (off Heap) 1300 / 1616 16.1 62.0 0.2X
+ BytesToBytesMap (on Heap) 1165 / 1202 18.0 55.5 0.2X
+ Aggregate HashMap 121 / 131 173.3 5.8 2.2X
+ */
benchmark.run()
}
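The reworked "fast hash" case replaces the murmur3 call with a plain modulo by a prime and normalizes negative remainders. A standalone sketch of that pattern on the same harness used throughout this suite (org.apache.spark.util.Benchmark, with the addCase/run API shown above, is assumed):

    import org.apache.spark.util.Benchmark

    val n = 1 << 20
    val benchmark = new Benchmark("modulo hash", n)
    benchmark.addCase("fast hash") { _ =>
      val p = 524283            // prime modulus, matching the case above
      var i = 0
      var s = 0
      while (i < n) {
        var h = i % p
        if (h < 0) h += p       // % can be negative for negative inputs; keep bucket ids non-negative
        s += h
        i += 1
      }
    }
    benchmark.run()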
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala
index 9680f3a008..17f2343cf9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala
@@ -38,8 +38,8 @@ class ExchangeSuite extends SparkPlanTest with SharedSQLContext {
test("compatible BroadcastMode") {
val mode1 = IdentityBroadcastMode
- val mode2 = HashedRelationBroadcastMode(true, Literal(1) :: Nil, Seq())
- val mode3 = HashedRelationBroadcastMode(false, Literal("s") :: Nil, Seq())
+ val mode2 = HashedRelationBroadcastMode(Literal(1L) :: Nil)
+ val mode3 = HashedRelationBroadcastMode(Literal("s") :: Nil)
assert(mode1.compatibleWith(mode1))
assert(!mode1.compatibleWith(mode2))
@@ -56,10 +56,10 @@ class ExchangeSuite extends SparkPlanTest with SharedSQLContext {
assert(plan sameResult plan)
val exchange1 = BroadcastExchange(IdentityBroadcastMode, plan)
- val hashMode = HashedRelationBroadcastMode(true, output, plan.output)
+ val hashMode = HashedRelationBroadcastMode(output)
val exchange2 = BroadcastExchange(hashMode, plan)
val hashMode2 =
- HashedRelationBroadcastMode(true, Alias(output.head, "id2")() :: Nil, plan.output)
+ HashedRelationBroadcastMode(Alias(output.head, "id2")() :: Nil)
val exchange3 = BroadcastExchange(hashMode2, plan)
val exchange4 = ReusedExchange(output, exchange3)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala
index ed0d3f56e5..38318740a5 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala
@@ -231,10 +231,8 @@ object SparkPlanTest {
}
private def executePlan(outputPlan: SparkPlan, sqlContext: SQLContext): Seq[Row] = {
- // A very simple resolver to make writing tests easier. In contrast to the real resolver
- // this is always case sensitive and does not try to handle scoping or complex type resolution.
- val resolvedPlan = sqlContext.sessionState.prepareForExecution.execute(
- outputPlan transform {
+ val execution = new QueryExecution(sqlContext, null) {
+ override lazy val sparkPlan: SparkPlan = outputPlan transform {
case plan: SparkPlan =>
val inputMap = plan.children.flatMap(_.output).map(a => (a.name, a)).toMap
plan transformExpressions {
@@ -243,8 +241,8 @@ object SparkPlanTest {
sys.error(s"Invalid Test: Cannot resolve $u given input $inputMap"))
}
}
- )
- resolvedPlan.executeCollectPublic().toSeq
+ }
+ execution.executedPlan.executeCollectPublic().toSeq
}
}
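The rewritten executePlan no longer calls prepareForExecution by hand; it overrides sparkPlan on an anonymous QueryExecution so the hand-built physical plan still passes through the standard preparation rules (exchange insertion, whole-stage codegen, and so on) before being collected. Stripped of the attribute-resolution transform, the pattern looks roughly like this, assuming the QueryExecution(sqlContext, logicalPlan) constructor used in the diff:

    import org.apache.spark.sql.{Row, SQLContext}
    import org.apache.spark.sql.execution.{QueryExecution, SparkPlan}

    def runPreparedPlan(outputPlan: SparkPlan, sqlContext: SQLContext): Seq[Row] = {
      val execution = new QueryExecution(sqlContext, null) {
        // Bypass planning: hand the test's physical plan straight to the preparation rules.
        override lazy val sparkPlan: SparkPlan = outputPlan
      }
      execution.executedPlan.executeCollectPublic().toSeq
    }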
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala
index 4dc7d3461c..c1555114e8 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala
@@ -17,6 +17,8 @@
package org.apache.spark.sql.execution
+import java.util.Properties
+
import scala.collection.mutable
import scala.util.{Random, Try}
import scala.util.control.NonFatal
@@ -71,6 +73,7 @@ class UnsafeFixedWidthAggregationMapSuite
taskAttemptId = Random.nextInt(10000),
attemptNumber = 0,
taskMemoryManager = taskMemoryManager,
+ localProperties = new Properties,
metricsSystem = null))
try {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala
index 476d93fc2a..03d4be8ee5 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala
@@ -17,6 +17,8 @@
package org.apache.spark.sql.execution
+import java.util.Properties
+
import scala.util.Random
import org.apache.spark._
@@ -117,6 +119,7 @@ class UnsafeKVExternalSorterSuite extends SparkFunSuite with SharedSQLContext {
taskAttemptId = 98456,
attemptNumber = 0,
taskMemoryManager = taskMemMgr,
+ localProperties = new Properties,
metricsSystem = null))
val sorter = new UnsafeKVExternalSorter(
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala
index 1f3779373b..01687877ee 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala
@@ -18,6 +18,7 @@
package org.apache.spark.sql.execution
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, File}
+import java.util.Properties
import org.apache.spark._
import org.apache.spark.memory.TaskMemoryManager
@@ -113,7 +114,7 @@ class UnsafeRowSerializerSuite extends SparkFunSuite with LocalSparkContext {
}
val taskMemoryManager = new TaskMemoryManager(sc.env.memoryManager, 0)
val taskContext = new TaskContextImpl(
- 0, 0, 0, 0, taskMemoryManager, null, InternalAccumulator.create(sc))
+ 0, 0, 0, 0, taskMemoryManager, new Properties, null, InternalAccumulator.createAll(sc))
val sorter = new ExternalSorter[Int, UnsafeRow, UnsafeRow](
taskContext,
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
index 6d5be0b5dd..4474cfcf6e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
@@ -70,4 +70,33 @@ class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext {
p.asInstanceOf[WholeStageCodegen].child.isInstanceOf[Sort]).isDefined)
assert(df.collect() === Array(Row(1), Row(2), Row(3)))
}
+
+ test("MapElements should be included in WholeStageCodegen") {
+ import testImplicits._
+
+ val ds = sqlContext.range(10).map(_.toString)
+ val plan = ds.queryExecution.executedPlan
+ assert(plan.find(p =>
+ p.isInstanceOf[WholeStageCodegen] &&
+ p.asInstanceOf[WholeStageCodegen].child.isInstanceOf[MapElements]).isDefined)
+ assert(ds.collect() === 0.until(10).map(_.toString).toArray)
+ }
+
+ test("typed filter should be included in WholeStageCodegen") {
+ val ds = sqlContext.range(10).filter(_ % 2 == 0)
+ val plan = ds.queryExecution.executedPlan
+ assert(plan.find(p =>
+ p.isInstanceOf[WholeStageCodegen] &&
+ p.asInstanceOf[WholeStageCodegen].child.isInstanceOf[Filter]).isDefined)
+ assert(ds.collect() === Array(0, 2, 4, 6, 8))
+ }
+
+ test("back-to-back typed filter should be included in WholeStageCodegen") {
+ val ds = sqlContext.range(10).filter(_ % 2 == 0).filter(_ % 3 == 0)
+ val plan = ds.queryExecution.executedPlan
+ assert(plan.find(p =>
+ p.isInstanceOf[WholeStageCodegen] &&
+ p.asInstanceOf[WholeStageCodegen].child.isInstanceOf[SerializeFromObject]).isDefined)
+ assert(ds.collect() === Array(0, 6))
+ }
}
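Each new test follows the same shape: plan the Dataset, then assert that the operator of interest ended up as the child of a WholeStageCodegen node. A generic form of that assertion, assuming a Dataset `ds` and the physical operators imported by this suite:

    val plan = ds.queryExecution.executedPlan
    val fused = plan.find {
      case w: WholeStageCodegen => w.child.isInstanceOf[Filter]   // or MapElements, Sort, ...
      case _ => false
    }
    assert(fused.isDefined, "expected the operator to be compiled by WholeStageCodegen")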
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala
index 9e04caf8ba..50c8745a28 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala
@@ -220,4 +220,14 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext {
assert(data.count() === 10)
assert(data.filter($"s" === "3").count() === 1)
}
+
+ test("SPARK-14138: Generated SpecificColumnarIterator can exceed JVM size limit for cached DF") {
+ val length1 = 3999
+ val columnTypes1 = List.fill(length1)(IntegerType)
+ val columnarIterator1 = GenerateColumnAccessor.generate(columnTypes1)
+
+ val length2 = 10000
+ val columnTypes2 = List.fill(length2)(IntegerType)
+ val columnarIterator2 = GenerateColumnAccessor.generate(columnTypes2)
+ }
}
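The SPARK-14138 regression test carries no explicit assertion: successfully generating the SpecificColumnarIterator for a very wide schema is the check, because before the fix the generated code exceeded the JVM's 64KB-per-method bytecode limit and failed to compile. The essence of the check, assuming GenerateColumnAccessor and IntegerType as imported by this suite:

    // Generation itself is the assertion: it throws if the produced
    // SpecificColumnarIterator no longer fits within JVM method-size limits.
    val wideSchema = List.fill(10000)(IntegerType)
    GenerateColumnAccessor.generate(wideSchema)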
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala
index a33175aa60..d6ccaf9348 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala
@@ -17,15 +17,24 @@
package org.apache.spark.sql.execution.command
+import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.TableIdentifier
-import org.apache.spark.sql.catalyst.expressions.{Ascending, Descending}
+import org.apache.spark.sql.catalyst.parser.ParseException
import org.apache.spark.sql.catalyst.plans.PlanTest
-import org.apache.spark.sql.execution.SparkQl
-import org.apache.spark.sql.execution.datasources.BucketSpec
+import org.apache.spark.sql.catalyst.plans.logical.Project
+import org.apache.spark.sql.execution.SparkSqlParser
import org.apache.spark.sql.types._
+// TODO: merge this with DDLSuite (SPARK-14441)
class DDLCommandSuite extends PlanTest {
- private val parser = new SparkQl
+ private val parser = SparkSqlParser
+
+ private def assertUnsupported(sql: String): Unit = {
+ val e = intercept[AnalysisException] {
+ parser.parsePlan(sql)
+ }
+ assert(e.getMessage.toLowerCase.contains("operation not allowed"))
+ }
test("create database") {
val sql =
@@ -40,7 +49,7 @@ class DDLCommandSuite extends PlanTest {
ifNotExists = true,
Some("/home/user/db"),
Some("database_comment"),
- Map("a" -> "a", "b" -> "b", "c" -> "c"))(sql)
+ Map("a" -> "a", "b" -> "b", "c" -> "c"))
comparePlans(parsed, expected)
}
@@ -66,39 +75,65 @@ class DDLCommandSuite extends PlanTest {
val expected1 = DropDatabase(
"database_name",
ifExists = true,
- restrict = true)(sql1)
+ cascade = false)
val expected2 = DropDatabase(
"database_name",
ifExists = true,
- restrict = false)(sql2)
+ cascade = true)
val expected3 = DropDatabase(
"database_name",
- ifExists = true,
- restrict = true)(sql3)
+ ifExists = false,
+ cascade = false)
val expected4 = DropDatabase(
"database_name",
- ifExists = true,
- restrict = false)(sql4)
- val expected5 = DropDatabase(
- "database_name",
- ifExists = true,
- restrict = true)(sql5)
- val expected6 = DropDatabase(
- "database_name",
ifExists = false,
- restrict = true)(sql6)
- val expected7 = DropDatabase(
+ cascade = true)
+
+ comparePlans(parsed1, expected1)
+ comparePlans(parsed2, expected2)
+ comparePlans(parsed3, expected1)
+ comparePlans(parsed4, expected2)
+ comparePlans(parsed5, expected1)
+ comparePlans(parsed6, expected3)
+ comparePlans(parsed7, expected4)
+ }
+
+ test("alter database set dbproperties") {
+ // ALTER (DATABASE|SCHEMA) database_name SET DBPROPERTIES (property_name=property_value, ...)
+ val sql1 = "ALTER DATABASE database_name SET DBPROPERTIES ('a'='a', 'b'='b', 'c'='c')"
+ val sql2 = "ALTER SCHEMA database_name SET DBPROPERTIES ('a'='a')"
+
+ val parsed1 = parser.parsePlan(sql1)
+ val parsed2 = parser.parsePlan(sql2)
+
+ val expected1 = AlterDatabaseProperties(
"database_name",
- ifExists = false,
- restrict = false)(sql7)
+ Map("a" -> "a", "b" -> "b", "c" -> "c"))
+ val expected2 = AlterDatabaseProperties(
+ "database_name",
+ Map("a" -> "a"))
+
+ comparePlans(parsed1, expected1)
+ comparePlans(parsed2, expected2)
+ }
+
+ test("describe database") {
+ // DESCRIBE DATABASE [EXTENDED] db_name;
+ val sql1 = "DESCRIBE DATABASE EXTENDED db_name"
+ val sql2 = "DESCRIBE DATABASE db_name"
+
+ val parsed1 = parser.parsePlan(sql1)
+ val parsed2 = parser.parsePlan(sql2)
+
+ val expected1 = DescribeDatabase(
+ "db_name",
+ extended = true)
+ val expected2 = DescribeDatabase(
+ "db_name",
+ extended = false)
comparePlans(parsed1, expected1)
comparePlans(parsed2, expected2)
- comparePlans(parsed3, expected3)
- comparePlans(parsed4, expected4)
- comparePlans(parsed5, expected5)
- comparePlans(parsed6, expected6)
- comparePlans(parsed7, expected7)
}
test("create function") {
@@ -117,46 +152,115 @@ class DDLCommandSuite extends PlanTest {
val parsed1 = parser.parsePlan(sql1)
val parsed2 = parser.parsePlan(sql2)
val expected1 = CreateFunction(
+ None,
"helloworld",
"com.matthewrathbone.example.SimpleUDFExample",
Seq(("jar", "/path/to/jar1"), ("jar", "/path/to/jar2")),
- isTemp = true)(sql1)
+ isTemp = true)
val expected2 = CreateFunction(
- "hello.world",
+ Some("hello"),
+ "world",
"com.matthewrathbone.example.SimpleUDFExample",
Seq(("archive", "/path/to/archive"), ("file", "/path/to/file")),
- isTemp = false)(sql2)
+ isTemp = false)
comparePlans(parsed1, expected1)
comparePlans(parsed2, expected2)
}
- test("alter table: rename table") {
- val sql = "ALTER TABLE table_name RENAME TO new_table_name"
- val parsed = parser.parsePlan(sql)
- val expected = AlterTableRename(
- TableIdentifier("table_name", None),
- TableIdentifier("new_table_name", None))(sql)
- comparePlans(parsed, expected)
- }
+ test("drop function") {
+ val sql1 = "DROP TEMPORARY FUNCTION helloworld"
+ val sql2 = "DROP TEMPORARY FUNCTION IF EXISTS helloworld"
+ val sql3 = "DROP FUNCTION hello.world"
+ val sql4 = "DROP FUNCTION IF EXISTS hello.world"
- test("alter table: alter table properties") {
- val sql1 = "ALTER TABLE table_name SET TBLPROPERTIES ('test' = 'test', " +
- "'comment' = 'new_comment')"
- val sql2 = "ALTER TABLE table_name UNSET TBLPROPERTIES ('comment', 'test')"
- val sql3 = "ALTER TABLE table_name UNSET TBLPROPERTIES IF EXISTS ('comment', 'test')"
val parsed1 = parser.parsePlan(sql1)
val parsed2 = parser.parsePlan(sql2)
val parsed3 = parser.parsePlan(sql3)
- val tableIdent = TableIdentifier("table_name", None)
- val expected1 = AlterTableSetProperties(
- tableIdent, Map("test" -> "test", "comment" -> "new_comment"))(sql1)
- val expected2 = AlterTableUnsetProperties(
- tableIdent, Map("comment" -> null, "test" -> null), ifExists = false)(sql2)
- val expected3 = AlterTableUnsetProperties(
- tableIdent, Map("comment" -> null, "test" -> null), ifExists = true)(sql3)
+ val parsed4 = parser.parsePlan(sql4)
+
+ val expected1 = DropFunction(
+ None,
+ "helloworld",
+ ifExists = false,
+ isTemp = true)
+ val expected2 = DropFunction(
+ None,
+ "helloworld",
+ ifExists = true,
+ isTemp = true)
+ val expected3 = DropFunction(
+ Some("hello"),
+ "world",
+ ifExists = false,
+ isTemp = false)
+ val expected4 = DropFunction(
+ Some("hello"),
+ "world",
+ ifExists = true,
+ isTemp = false)
+
comparePlans(parsed1, expected1)
comparePlans(parsed2, expected2)
comparePlans(parsed3, expected3)
+ comparePlans(parsed4, expected4)
+ }
+
+ // ALTER TABLE table_name RENAME TO new_table_name;
+ // ALTER VIEW view_name RENAME TO new_view_name;
+ test("alter table/view: rename table/view") {
+ val sql_table = "ALTER TABLE table_name RENAME TO new_table_name"
+ val sql_view = sql_table.replace("TABLE", "VIEW")
+ val parsed_table = parser.parsePlan(sql_table)
+ val parsed_view = parser.parsePlan(sql_view)
+ val expected_table = AlterTableRename(
+ TableIdentifier("table_name", None),
+ TableIdentifier("new_table_name", None),
+ isView = false)
+ val expected_view = AlterTableRename(
+ TableIdentifier("table_name", None),
+ TableIdentifier("new_table_name", None),
+ isView = true)
+ comparePlans(parsed_table, expected_table)
+ comparePlans(parsed_view, expected_view)
+ }
+
+ // ALTER TABLE table_name SET TBLPROPERTIES ('comment' = new_comment);
+ // ALTER TABLE table_name UNSET TBLPROPERTIES [IF EXISTS] ('comment', 'key');
+ // ALTER VIEW view_name SET TBLPROPERTIES ('comment' = new_comment);
+ // ALTER VIEW view_name UNSET TBLPROPERTIES [IF EXISTS] ('comment', 'key');
+ test("alter table/view: alter table/view properties") {
+ val sql1_table = "ALTER TABLE table_name SET TBLPROPERTIES ('test' = 'test', " +
+ "'comment' = 'new_comment')"
+ val sql2_table = "ALTER TABLE table_name UNSET TBLPROPERTIES ('comment', 'test')"
+ val sql3_table = "ALTER TABLE table_name UNSET TBLPROPERTIES IF EXISTS ('comment', 'test')"
+ val sql1_view = sql1_table.replace("TABLE", "VIEW")
+ val sql2_view = sql2_table.replace("TABLE", "VIEW")
+ val sql3_view = sql3_table.replace("TABLE", "VIEW")
+
+ val parsed1_table = parser.parsePlan(sql1_table)
+ val parsed2_table = parser.parsePlan(sql2_table)
+ val parsed3_table = parser.parsePlan(sql3_table)
+ val parsed1_view = parser.parsePlan(sql1_view)
+ val parsed2_view = parser.parsePlan(sql2_view)
+ val parsed3_view = parser.parsePlan(sql3_view)
+
+ val tableIdent = TableIdentifier("table_name", None)
+ val expected1_table = AlterTableSetProperties(
+ tableIdent, Map("test" -> "test", "comment" -> "new_comment"), isView = false)
+ val expected2_table = AlterTableUnsetProperties(
+ tableIdent, Seq("comment", "test"), ifExists = false, isView = false)
+ val expected3_table = AlterTableUnsetProperties(
+ tableIdent, Seq("comment", "test"), ifExists = true, isView = false)
+ val expected1_view = expected1_table.copy(isView = true)
+ val expected2_view = expected2_table.copy(isView = true)
+ val expected3_view = expected3_table.copy(isView = true)
+
+ comparePlans(parsed1_table, expected1_table)
+ comparePlans(parsed2_table, expected2_table)
+ comparePlans(parsed3_table, expected3_table)
+ comparePlans(parsed1_view, expected1_view)
+ comparePlans(parsed2_view, expected2_view)
+ comparePlans(parsed3_view, expected3_view)
}
test("alter table: SerDe properties") {
@@ -189,24 +293,24 @@ class DDLCommandSuite extends PlanTest {
val parsed5 = parser.parsePlan(sql5)
val tableIdent = TableIdentifier("table_name", None)
val expected1 = AlterTableSerDeProperties(
- tableIdent, Some("org.apache.class"), None, None)(sql1)
+ tableIdent, Some("org.apache.class"), None, None)
val expected2 = AlterTableSerDeProperties(
tableIdent,
Some("org.apache.class"),
Some(Map("columns" -> "foo,bar", "field.delim" -> ",")),
- None)(sql2)
+ None)
val expected3 = AlterTableSerDeProperties(
- tableIdent, None, Some(Map("columns" -> "foo,bar", "field.delim" -> ",")), None)(sql3)
+ tableIdent, None, Some(Map("columns" -> "foo,bar", "field.delim" -> ",")), None)
val expected4 = AlterTableSerDeProperties(
tableIdent,
Some("org.apache.class"),
Some(Map("columns" -> "foo,bar", "field.delim" -> ",")),
- Some(Map("test" -> null, "dt" -> "2008-08-08", "country" -> "us")))(sql4)
+ Some(Map("test" -> null, "dt" -> "2008-08-08", "country" -> "us")))
val expected5 = AlterTableSerDeProperties(
tableIdent,
None,
Some(Map("columns" -> "foo,bar", "field.delim" -> ",")),
- Some(Map("test" -> null, "dt" -> "2008-08-08", "country" -> "us")))(sql5)
+ Some(Map("test" -> null, "dt" -> "2008-08-08", "country" -> "us")))
comparePlans(parsed1, expected1)
comparePlans(parsed2, expected2)
comparePlans(parsed3, expected3)
@@ -214,118 +318,56 @@ class DDLCommandSuite extends PlanTest {
comparePlans(parsed5, expected5)
}
- test("alter table: storage properties") {
- val sql1 = "ALTER TABLE table_name CLUSTERED BY (dt, country) INTO 10 BUCKETS"
- val sql2 = "ALTER TABLE table_name CLUSTERED BY (dt, country) SORTED BY " +
- "(dt, country DESC) INTO 10 BUCKETS"
- val sql3 = "ALTER TABLE table_name NOT CLUSTERED"
- val sql4 = "ALTER TABLE table_name NOT SORTED"
- val parsed1 = parser.parsePlan(sql1)
- val parsed2 = parser.parsePlan(sql2)
- val parsed3 = parser.parsePlan(sql3)
- val parsed4 = parser.parsePlan(sql4)
- val tableIdent = TableIdentifier("table_name", None)
- val cols = List("dt", "country")
- // TODO: also test the sort directions once we keep track of that
- val expected1 = AlterTableStorageProperties(
- tableIdent, BucketSpec(10, cols, Nil))(sql1)
- val expected2 = AlterTableStorageProperties(
- tableIdent, BucketSpec(10, cols, cols))(sql2)
- val expected3 = AlterTableNotClustered(tableIdent)(sql3)
- val expected4 = AlterTableNotSorted(tableIdent)(sql4)
- comparePlans(parsed1, expected1)
- comparePlans(parsed2, expected2)
- comparePlans(parsed3, expected3)
- comparePlans(parsed4, expected4)
- }
-
- test("alter table: skewed") {
+ // ALTER TABLE table_name ADD [IF NOT EXISTS] PARTITION partition_spec
+ // [LOCATION 'location1'] partition_spec [LOCATION 'location2'] ...;
+ test("alter table: add partition") {
val sql1 =
"""
- |ALTER TABLE table_name SKEWED BY (dt, country) ON
- |(('2008-08-08', 'us'), ('2009-09-09', 'uk'), ('2010-10-10', 'cn')) STORED AS DIRECTORIES
- """.stripMargin
- val sql2 =
- """
- |ALTER TABLE table_name SKEWED BY (dt, country) ON
- |('2008-08-08', 'us') STORED AS DIRECTORIES
- """.stripMargin
- val sql3 =
- """
- |ALTER TABLE table_name SKEWED BY (dt, country) ON
- |(('2008-08-08', 'us'), ('2009-09-09', 'uk'))
+ |ALTER TABLE table_name ADD IF NOT EXISTS PARTITION
+ |(dt='2008-08-08', country='us') LOCATION 'location1' PARTITION
+ |(dt='2009-09-09', country='uk')
""".stripMargin
- val sql4 = "ALTER TABLE table_name NOT SKEWED"
- val sql5 = "ALTER TABLE table_name NOT STORED AS DIRECTORIES"
+ val sql2 = "ALTER TABLE table_name ADD PARTITION (dt='2008-08-08') LOCATION 'loc'"
+
val parsed1 = parser.parsePlan(sql1)
val parsed2 = parser.parsePlan(sql2)
- val parsed3 = parser.parsePlan(sql3)
- val parsed4 = parser.parsePlan(sql4)
- val parsed5 = parser.parsePlan(sql5)
- val tableIdent = TableIdentifier("table_name", None)
- val expected1 = AlterTableSkewed(
- tableIdent,
- Seq("dt", "country"),
- Seq(List("2008-08-08", "us"), List("2009-09-09", "uk"), List("2010-10-10", "cn")),
- storedAsDirs = true)(sql1)
- val expected2 = AlterTableSkewed(
- tableIdent,
- Seq("dt", "country"),
- Seq(List("2008-08-08", "us")),
- storedAsDirs = true)(sql2)
- val expected3 = AlterTableSkewed(
- tableIdent,
- Seq("dt", "country"),
- Seq(List("2008-08-08", "us"), List("2009-09-09", "uk")),
- storedAsDirs = false)(sql3)
- val expected4 = AlterTableNotSkewed(tableIdent)(sql4)
- val expected5 = AlterTableNotStoredAsDirs(tableIdent)(sql5)
+
+ val expected1 = AlterTableAddPartition(
+ TableIdentifier("table_name", None),
+ Seq(
+ (Map("dt" -> "2008-08-08", "country" -> "us"), Some("location1")),
+ (Map("dt" -> "2009-09-09", "country" -> "uk"), None)),
+ ifNotExists = true)
+ val expected2 = AlterTableAddPartition(
+ TableIdentifier("table_name", None),
+ Seq((Map("dt" -> "2008-08-08"), Some("loc"))),
+ ifNotExists = false)
+
comparePlans(parsed1, expected1)
comparePlans(parsed2, expected2)
- comparePlans(parsed3, expected3)
- comparePlans(parsed4, expected4)
- comparePlans(parsed5, expected5)
}
- test("alter table: skewed location") {
+ // ALTER VIEW view_name ADD [IF NOT EXISTS] PARTITION partition_spec PARTITION partition_spec ...;
+ test("alter view: add partition") {
val sql1 =
"""
- |ALTER TABLE table_name SET SKEWED LOCATION
- |('123'='location1', 'test'='location2')
+ |ALTER VIEW view_name ADD IF NOT EXISTS PARTITION
+ |(dt='2008-08-08', country='us') PARTITION
+ |(dt='2009-09-09', country='uk')
""".stripMargin
+ // different constant types in partitioning spec
val sql2 =
- """
- |ALTER TABLE table_name SET SKEWED LOCATION
- |(('2008-08-08', 'us')='location1', 'test'='location2')
- """.stripMargin
- val parsed1 = parser.parsePlan(sql1)
- val parsed2 = parser.parsePlan(sql2)
- val tableIdent = TableIdentifier("table_name", None)
- val expected1 = AlterTableSkewedLocation(
- tableIdent,
- Map("123" -> "location1", "test" -> "location2"))(sql1)
- val expected2 = AlterTableSkewedLocation(
- tableIdent,
- Map("2008-08-08" -> "location1", "us" -> "location1", "test" -> "location2"))(sql2)
- comparePlans(parsed1, expected1)
- comparePlans(parsed2, expected2)
- }
+ """
+ |ALTER VIEW view_name ADD PARTITION
+ |(col1=NULL, cOL2='f', col3=5, COL4=true)
+ """.stripMargin
- test("alter table: add partition") {
- val sql =
- """
- |ALTER TABLE table_name ADD IF NOT EXISTS PARTITION
- |(dt='2008-08-08', country='us') LOCATION 'location1' PARTITION
- |(dt='2009-09-09', country='uk')
- """.stripMargin
- val parsed = parser.parsePlan(sql)
- val expected = AlterTableAddPartition(
- TableIdentifier("table_name", None),
- Seq(
- (Map("dt" -> "2008-08-08", "country" -> "us"), Some("location1")),
- (Map("dt" -> "2009-09-09", "country" -> "uk"), None)),
- ifNotExists = true)(sql)
- comparePlans(parsed, expected)
+ intercept[ParseException] {
+ parser.parsePlan(sql1)
+ }
+ intercept[ParseException] {
+ parser.parsePlan(sql2)
+ }
}
test("alter table: rename partition") {
@@ -338,106 +380,87 @@ class DDLCommandSuite extends PlanTest {
val expected = AlterTableRenamePartition(
TableIdentifier("table_name", None),
Map("dt" -> "2008-08-08", "country" -> "us"),
- Map("dt" -> "2008-09-09", "country" -> "uk"))(sql)
+ Map("dt" -> "2008-09-09", "country" -> "uk"))
comparePlans(parsed, expected)
}
- test("alter table: exchange partition") {
- val sql =
+ test("alter table: exchange partition (not supported)") {
+ assertUnsupported(
"""
|ALTER TABLE table_name_1 EXCHANGE PARTITION
|(dt='2008-08-08', country='us') WITH TABLE table_name_2
- """.stripMargin
- val parsed = parser.parsePlan(sql)
- val expected = AlterTableExchangePartition(
- TableIdentifier("table_name_1", None),
- TableIdentifier("table_name_2", None),
- Map("dt" -> "2008-08-08", "country" -> "us"))(sql)
- comparePlans(parsed, expected)
+ """.stripMargin)
}
- test("alter table: drop partitions") {
- val sql1 =
+ // ALTER TABLE table_name DROP [IF EXISTS] PARTITION spec1[, PARTITION spec2, ...] [PURGE]
+ // ALTER VIEW table_name DROP [IF EXISTS] PARTITION spec1[, PARTITION spec2, ...]
+ test("alter table/view: drop partitions") {
+ val sql1_table =
"""
|ALTER TABLE table_name DROP IF EXISTS PARTITION
|(dt='2008-08-08', country='us'), PARTITION (dt='2009-09-09', country='uk')
""".stripMargin
- val sql2 =
+ val sql2_table =
"""
|ALTER TABLE table_name DROP PARTITION
|(dt='2008-08-08', country='us'), PARTITION (dt='2009-09-09', country='uk') PURGE
""".stripMargin
- val parsed1 = parser.parsePlan(sql1)
- val parsed2 = parser.parsePlan(sql2)
+ val sql1_view = sql1_table.replace("TABLE", "VIEW")
+ // Note: ALTER VIEW DROP PARTITION does not support PURGE
+ val sql2_view = sql2_table.replace("TABLE", "VIEW").replace("PURGE", "")
+
+ val parsed1_table = parser.parsePlan(sql1_table)
+ val e = intercept[ParseException] {
+ parser.parsePlan(sql2_table)
+ }
+ assert(e.getMessage.contains("Operation not allowed"))
+
+ intercept[ParseException] {
+ parser.parsePlan(sql1_view)
+ }
+ intercept[ParseException] {
+ parser.parsePlan(sql2_view)
+ }
+
val tableIdent = TableIdentifier("table_name", None)
- val expected1 = AlterTableDropPartition(
- tableIdent,
- Seq(
- Map("dt" -> "2008-08-08", "country" -> "us"),
- Map("dt" -> "2009-09-09", "country" -> "uk")),
- ifExists = true,
- purge = false)(sql1)
- val expected2 = AlterTableDropPartition(
+ val expected1_table = AlterTableDropPartition(
tableIdent,
Seq(
Map("dt" -> "2008-08-08", "country" -> "us"),
Map("dt" -> "2009-09-09", "country" -> "uk")),
- ifExists = false,
- purge = true)(sql2)
- comparePlans(parsed1, expected1)
- comparePlans(parsed2, expected2)
+ ifExists = true)
+
+ comparePlans(parsed1_table, expected1_table)
}
- test("alter table: archive partition") {
- val sql = "ALTER TABLE table_name ARCHIVE PARTITION (dt='2008-08-08', country='us')"
- val parsed = parser.parsePlan(sql)
- val expected = AlterTableArchivePartition(
- TableIdentifier("table_name", None),
- Map("dt" -> "2008-08-08", "country" -> "us"))(sql)
- comparePlans(parsed, expected)
+ test("alter table: archive partition (not supported)") {
+ assertUnsupported("ALTER TABLE table_name ARCHIVE PARTITION (dt='2008-08-08', country='us')")
}
- test("alter table: unarchive partition") {
- val sql = "ALTER TABLE table_name UNARCHIVE PARTITION (dt='2008-08-08', country='us')"
- val parsed = parser.parsePlan(sql)
- val expected = AlterTableUnarchivePartition(
- TableIdentifier("table_name", None),
- Map("dt" -> "2008-08-08", "country" -> "us"))(sql)
- comparePlans(parsed, expected)
+ test("alter table: unarchive partition (not supported)") {
+ assertUnsupported("ALTER TABLE table_name UNARCHIVE PARTITION (dt='2008-08-08', country='us')")
}
test("alter table: set file format") {
- val sql1 =
- """
- |ALTER TABLE table_name SET FILEFORMAT INPUTFORMAT 'test'
- |OUTPUTFORMAT 'test' SERDE 'test' INPUTDRIVER 'test' OUTPUTDRIVER 'test'
- """.stripMargin
- val sql2 = "ALTER TABLE table_name SET FILEFORMAT INPUTFORMAT 'test' " +
+ val sql1 = "ALTER TABLE table_name SET FILEFORMAT INPUTFORMAT 'test' " +
"OUTPUTFORMAT 'test' SERDE 'test'"
- val sql3 = "ALTER TABLE table_name PARTITION (dt='2008-08-08', country='us') " +
+ val sql2 = "ALTER TABLE table_name PARTITION (dt='2008-08-08', country='us') " +
"SET FILEFORMAT PARQUET"
val parsed1 = parser.parsePlan(sql1)
val parsed2 = parser.parsePlan(sql2)
- val parsed3 = parser.parsePlan(sql3)
val tableIdent = TableIdentifier("table_name", None)
val expected1 = AlterTableSetFileFormat(
tableIdent,
None,
- List("test", "test", "test", "test", "test"),
+ List("test", "test", "test"),
None)(sql1)
val expected2 = AlterTableSetFileFormat(
tableIdent,
- None,
- List("test", "test", "test"),
- None)(sql2)
- val expected3 = AlterTableSetFileFormat(
- tableIdent,
Some(Map("dt" -> "2008-08-08", "country" -> "us")),
Seq(),
- Some("PARQUET"))(sql3)
+ Some("PARQUET"))(sql2)
comparePlans(parsed1, expected1)
comparePlans(parsed2, expected2)
- comparePlans(parsed3, expected3)
}
test("alter table: set location") {
@@ -450,64 +473,33 @@ class DDLCommandSuite extends PlanTest {
val expected1 = AlterTableSetLocation(
tableIdent,
None,
- "new location")(sql1)
+ "new location")
val expected2 = AlterTableSetLocation(
tableIdent,
Some(Map("dt" -> "2008-08-08", "country" -> "us")),
- "new location")(sql2)
+ "new location")
comparePlans(parsed1, expected1)
comparePlans(parsed2, expected2)
}
- test("alter table: touch") {
- val sql1 = "ALTER TABLE table_name TOUCH"
- val sql2 = "ALTER TABLE table_name TOUCH PARTITION (dt='2008-08-08', country='us')"
- val parsed1 = parser.parsePlan(sql1)
- val parsed2 = parser.parsePlan(sql2)
- val tableIdent = TableIdentifier("table_name", None)
- val expected1 = AlterTableTouch(
- tableIdent,
- None)(sql1)
- val expected2 = AlterTableTouch(
- tableIdent,
- Some(Map("dt" -> "2008-08-08", "country" -> "us")))(sql2)
- comparePlans(parsed1, expected1)
- comparePlans(parsed2, expected2)
+ test("alter table: touch (not supported)") {
+ assertUnsupported("ALTER TABLE table_name TOUCH")
+ assertUnsupported("ALTER TABLE table_name TOUCH PARTITION (dt='2008-08-08', country='us')")
}
- test("alter table: compact") {
- val sql1 = "ALTER TABLE table_name COMPACT 'compaction_type'"
- val sql2 =
+ test("alter table: compact (not supported)") {
+ assertUnsupported("ALTER TABLE table_name COMPACT 'compaction_type'")
+ assertUnsupported(
"""
- |ALTER TABLE table_name PARTITION (dt='2008-08-08', country='us')
- |COMPACT 'MAJOR'
- """.stripMargin
- val parsed1 = parser.parsePlan(sql1)
- val parsed2 = parser.parsePlan(sql2)
- val tableIdent = TableIdentifier("table_name", None)
- val expected1 = AlterTableCompact(
- tableIdent,
- None,
- "compaction_type")(sql1)
- val expected2 = AlterTableCompact(
- tableIdent,
- Some(Map("dt" -> "2008-08-08", "country" -> "us")),
- "MAJOR")(sql2)
- comparePlans(parsed1, expected1)
- comparePlans(parsed2, expected2)
+ |ALTER TABLE table_name PARTITION (dt='2008-08-08', country='us')
+ |COMPACT 'MAJOR'
+ """.stripMargin)
}
- test("alter table: concatenate") {
- val sql1 = "ALTER TABLE table_name CONCATENATE"
- val sql2 = "ALTER TABLE table_name PARTITION (dt='2008-08-08', country='us') CONCATENATE"
- val parsed1 = parser.parsePlan(sql1)
- val parsed2 = parser.parsePlan(sql2)
- val tableIdent = TableIdentifier("table_name", None)
- val expected1 = AlterTableMerge(tableIdent, None)(sql1)
- val expected2 = AlterTableMerge(
- tableIdent, Some(Map("dt" -> "2008-08-08", "country" -> "us")))(sql2)
- comparePlans(parsed1, expected1)
- comparePlans(parsed2, expected2)
+ test("alter table: concatenate (not supported)") {
+ assertUnsupported("ALTER TABLE table_name CONCATENATE")
+ assertUnsupported(
+ "ALTER TABLE table_name PARTITION (dt='2008-08-08', country='us') CONCATENATE")
}
test("alter table: change column name/type/position/comment") {
@@ -598,4 +590,110 @@ class DDLCommandSuite extends PlanTest {
comparePlans(parsed2, expected2)
}
+ test("show databases") {
+ val sql1 = "SHOW DATABASES"
+ val sql2 = "SHOW DATABASES LIKE 'defau*'"
+ val parsed1 = parser.parsePlan(sql1)
+ val expected1 = ShowDatabasesCommand(None)
+ val parsed2 = parser.parsePlan(sql2)
+ val expected2 = ShowDatabasesCommand(Some("defau*"))
+ comparePlans(parsed1, expected1)
+ comparePlans(parsed2, expected2)
+ }
+
+ test("show tblproperties") {
+ val parsed1 = parser.parsePlan("SHOW TBLPROPERTIES tab1")
+ val expected1 = ShowTablePropertiesCommand(TableIdentifier("tab1", None), None)
+ val parsed2 = parser.parsePlan("SHOW TBLPROPERTIES tab1('propKey1')")
+ val expected2 = ShowTablePropertiesCommand(TableIdentifier("tab1", None), Some("propKey1"))
+ comparePlans(parsed1, expected1)
+ comparePlans(parsed2, expected2)
+ }
+
+ test("unsupported operations") {
+ intercept[ParseException] {
+ parser.parsePlan("DROP TABLE tab PURGE")
+ }
+ intercept[ParseException] {
+ parser.parsePlan("DROP TABLE tab FOR REPLICATION('eventid')")
+ }
+ intercept[ParseException] {
+ parser.parsePlan("CREATE VIEW testView AS SELECT id FROM tab")
+ }
+ intercept[ParseException] {
+ parser.parsePlan("ALTER VIEW testView AS SELECT id FROM tab")
+ }
+ intercept[ParseException] {
+ parser.parsePlan(
+ """
+ |CREATE EXTERNAL TABLE parquet_tab2(c1 INT, c2 STRING)
+ |TBLPROPERTIES('prop1Key '= "prop1Val", ' `prop2Key` '= "prop2Val")
+ """.stripMargin)
+ }
+ intercept[ParseException] {
+ parser.parsePlan(
+ """
+ |CREATE EXTERNAL TABLE oneToTenDef
+ |USING org.apache.spark.sql.sources
+ |OPTIONS (from '1', to '10')
+ """.stripMargin)
+ }
+ intercept[ParseException] {
+ parser.parsePlan("SELECT TRANSFORM (key, value) USING 'cat' AS (tKey, tValue) FROM testData")
+ }
+ }
+
+ test("SPARK-14383: DISTRIBUTE and UNSET as non-keywords") {
+ val sql = "SELECT distribute, unset FROM x"
+ val parsed = parser.parsePlan(sql)
+ assert(parsed.isInstanceOf[Project])
+ }
+
+ test("drop table") {
+ val tableName1 = "db.tab"
+ val tableName2 = "tab"
+
+ val parsed1 = parser.parsePlan(s"DROP TABLE $tableName1")
+ val parsed2 = parser.parsePlan(s"DROP TABLE IF EXISTS $tableName1")
+ val parsed3 = parser.parsePlan(s"DROP TABLE $tableName2")
+ val parsed4 = parser.parsePlan(s"DROP TABLE IF EXISTS $tableName2")
+
+ val expected1 =
+ DropTable(TableIdentifier("tab", Option("db")), ifExists = false, isView = false)
+ val expected2 =
+ DropTable(TableIdentifier("tab", Option("db")), ifExists = true, isView = false)
+ val expected3 =
+ DropTable(TableIdentifier("tab", None), ifExists = false, isView = false)
+ val expected4 =
+ DropTable(TableIdentifier("tab", None), ifExists = true, isView = false)
+
+ comparePlans(parsed1, expected1)
+ comparePlans(parsed2, expected2)
+ comparePlans(parsed3, expected3)
+ comparePlans(parsed4, expected4)
+ }
+
+ test("drop view") {
+ val viewName1 = "db.view"
+ val viewName2 = "view"
+
+ val parsed1 = parser.parsePlan(s"DROP VIEW $viewName1")
+ val parsed2 = parser.parsePlan(s"DROP VIEW IF EXISTS $viewName1")
+ val parsed3 = parser.parsePlan(s"DROP VIEW $viewName2")
+ val parsed4 = parser.parsePlan(s"DROP VIEW IF EXISTS $viewName2")
+
+ val expected1 =
+ DropTable(TableIdentifier("view", Option("db")), ifExists = false, isView = true)
+ val expected2 =
+ DropTable(TableIdentifier("view", Option("db")), ifExists = true, isView = true)
+ val expected3 =
+ DropTable(TableIdentifier("view", None), ifExists = false, isView = true)
+ val expected4 =
+ DropTable(TableIdentifier("view", None), ifExists = true, isView = true)
+
+ comparePlans(parsed1, expected1)
+ comparePlans(parsed2, expected2)
+ comparePlans(parsed3, expected3)
+ comparePlans(parsed4, expected4)
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala
new file mode 100644
index 0000000000..9ffffa0bdd
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala
@@ -0,0 +1,719 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.command
+
+import java.io.File
+
+import org.scalatest.BeforeAndAfterEach
+
+import org.apache.spark.sql.{AnalysisException, QueryTest, Row}
+import org.apache.spark.sql.catalyst.TableIdentifier
+import org.apache.spark.sql.catalyst.catalog.{CatalogDatabase, CatalogStorageFormat}
+import org.apache.spark.sql.catalyst.catalog.{CatalogTable, CatalogTableType}
+import org.apache.spark.sql.catalyst.catalog.{CatalogTablePartition, SessionCatalog}
+import org.apache.spark.sql.catalyst.catalog.ExternalCatalog.TablePartitionSpec
+import org.apache.spark.sql.test.SharedSQLContext
+
+class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach {
+ private val escapedIdentifier = "`(.+)`".r
+
+ override def afterEach(): Unit = {
+ try {
+ // drop all databases, tables and functions after each test
+ sqlContext.sessionState.catalog.reset()
+ } finally {
+ super.afterEach()
+ }
+ }
+
+ /**
+ * Strip backticks, if any, from the string.
+ */
+ private def cleanIdentifier(ident: String): String = {
+ ident match {
+ case escapedIdentifier(i) => i
+ case plainIdent => plainIdent
+ }
+ }
+
+ private def assertUnsupported(query: String): Unit = {
+ val e = intercept[AnalysisException] {
+ sql(query)
+ }
+ assert(e.getMessage.toLowerCase.contains("operation not allowed"))
+ }
+
+ private def maybeWrapException[T](expectException: Boolean)(body: => T): Unit = {
+ if (expectException) intercept[AnalysisException] { body } else body
+ }
+
+ private def createDatabase(catalog: SessionCatalog, name: String): Unit = {
+ catalog.createDatabase(CatalogDatabase(name, "", "", Map()), ignoreIfExists = false)
+ }
+
+ private def createTable(catalog: SessionCatalog, name: TableIdentifier): Unit = {
+ catalog.createTable(CatalogTable(
+ identifier = name,
+ tableType = CatalogTableType.EXTERNAL_TABLE,
+ storage = CatalogStorageFormat(None, None, None, None, Map()),
+ schema = Seq()), ignoreIfExists = false)
+ }
+
+ private def createTablePartition(
+ catalog: SessionCatalog,
+ spec: TablePartitionSpec,
+ tableName: TableIdentifier): Unit = {
+ val part = CatalogTablePartition(spec, CatalogStorageFormat(None, None, None, None, Map()))
+ catalog.createPartitions(tableName, Seq(part), ignoreIfExists = false)
+ }
+
+ test("Create/Drop Database") {
+ val catalog = sqlContext.sessionState.catalog
+
+ val databaseNames = Seq("db1", "`database`")
+
+ databaseNames.foreach { dbName =>
+ try {
+ val dbNameWithoutBackTicks = cleanIdentifier(dbName)
+
+ sql(s"CREATE DATABASE $dbName")
+ val db1 = catalog.getDatabaseMetadata(dbNameWithoutBackTicks)
+ assert(db1 == CatalogDatabase(
+ dbNameWithoutBackTicks,
+ "",
+ System.getProperty("java.io.tmpdir") + File.separator + s"$dbNameWithoutBackTicks.db",
+ Map.empty))
+ sql(s"DROP DATABASE $dbName CASCADE")
+ assert(!catalog.databaseExists(dbNameWithoutBackTicks))
+ } finally {
+ catalog.reset()
+ }
+ }
+ }
+
+ test("Create Database - database already exists") {
+ val catalog = sqlContext.sessionState.catalog
+ val databaseNames = Seq("db1", "`database`")
+
+ databaseNames.foreach { dbName =>
+ try {
+ val dbNameWithoutBackTicks = cleanIdentifier(dbName)
+ sql(s"CREATE DATABASE $dbName")
+ val db1 = catalog.getDatabaseMetadata(dbNameWithoutBackTicks)
+ assert(db1 == CatalogDatabase(
+ dbNameWithoutBackTicks,
+ "",
+ System.getProperty("java.io.tmpdir") + File.separator + s"$dbNameWithoutBackTicks.db",
+ Map.empty))
+
+ val message = intercept[AnalysisException] {
+ sql(s"CREATE DATABASE $dbName")
+ }.getMessage
+ assert(message.contains(s"Database '$dbNameWithoutBackTicks' already exists."))
+ } finally {
+ catalog.reset()
+ }
+ }
+ }
+
+ test("Alter/Describe Database") {
+ val catalog = sqlContext.sessionState.catalog
+ val databaseNames = Seq("db1", "`database`")
+
+ databaseNames.foreach { dbName =>
+ try {
+ val dbNameWithoutBackTicks = cleanIdentifier(dbName)
+ val location =
+ System.getProperty("java.io.tmpdir") + File.separator + s"$dbNameWithoutBackTicks.db"
+ sql(s"CREATE DATABASE $dbName")
+
+ checkAnswer(
+ sql(s"DESCRIBE DATABASE EXTENDED $dbName"),
+ Row("Database Name", dbNameWithoutBackTicks) ::
+ Row("Description", "") ::
+ Row("Location", location) ::
+ Row("Properties", "") :: Nil)
+
+ sql(s"ALTER DATABASE $dbName SET DBPROPERTIES ('a'='a', 'b'='b', 'c'='c')")
+
+ checkAnswer(
+ sql(s"DESCRIBE DATABASE EXTENDED $dbName"),
+ Row("Database Name", dbNameWithoutBackTicks) ::
+ Row("Description", "") ::
+ Row("Location", location) ::
+ Row("Properties", "((a,a), (b,b), (c,c))") :: Nil)
+
+ sql(s"ALTER DATABASE $dbName SET DBPROPERTIES ('d'='d')")
+
+ checkAnswer(
+ sql(s"DESCRIBE DATABASE EXTENDED $dbName"),
+ Row("Database Name", dbNameWithoutBackTicks) ::
+ Row("Description", "") ::
+ Row("Location", location) ::
+ Row("Properties", "((a,a), (b,b), (c,c), (d,d))") :: Nil)
+ } finally {
+ catalog.reset()
+ }
+ }
+ }
+
+ test("Drop/Alter/Describe Database - database does not exists") {
+ val databaseNames = Seq("db1", "`database`")
+
+ databaseNames.foreach { dbName =>
+ val dbNameWithoutBackTicks = cleanIdentifier(dbName)
+ assert(!sqlContext.sessionState.catalog.databaseExists(dbNameWithoutBackTicks))
+
+ var message = intercept[AnalysisException] {
+ sql(s"DROP DATABASE $dbName")
+ }.getMessage
+ assert(message.contains(s"Database '$dbNameWithoutBackTicks' does not exist"))
+
+ message = intercept[AnalysisException] {
+ sql(s"ALTER DATABASE $dbName SET DBPROPERTIES ('d'='d')")
+ }.getMessage
+ assert(message.contains(s"Database '$dbNameWithoutBackTicks' does not exist"))
+
+ message = intercept[AnalysisException] {
+ sql(s"DESCRIBE DATABASE EXTENDED $dbName")
+ }.getMessage
+ assert(message.contains(s"Database '$dbNameWithoutBackTicks' does not exist"))
+
+ sql(s"DROP DATABASE IF EXISTS $dbName")
+ }
+ }
+
+ // TODO: test drop database in restrict mode
+
+ test("alter table: rename") {
+ val catalog = sqlContext.sessionState.catalog
+ val tableIdent1 = TableIdentifier("tab1", Some("dbx"))
+ val tableIdent2 = TableIdentifier("tab2", Some("dbx"))
+ createDatabase(catalog, "dbx")
+ createDatabase(catalog, "dby")
+ createTable(catalog, tableIdent1)
+ assert(catalog.listTables("dbx") == Seq(tableIdent1))
+ sql("ALTER TABLE dbx.tab1 RENAME TO dbx.tab2")
+ assert(catalog.listTables("dbx") == Seq(tableIdent2))
+ catalog.setCurrentDatabase("dbx")
+ // rename without explicitly specifying database
+ sql("ALTER TABLE tab2 RENAME TO tab1")
+ assert(catalog.listTables("dbx") == Seq(tableIdent1))
+ // table to rename does not exist
+ intercept[AnalysisException] {
+ sql("ALTER TABLE dbx.does_not_exist RENAME TO dbx.tab2")
+ }
+ // destination database is different
+ intercept[AnalysisException] {
+ sql("ALTER TABLE dbx.tab1 RENAME TO dby.tab2")
+ }
+ }
+
+ test("alter table: set location") {
+ testSetLocation(isDatasourceTable = false)
+ }
+
+ test("alter table: set location (datasource table)") {
+ testSetLocation(isDatasourceTable = true)
+ }
+
+ test("alter table: set properties") {
+ val catalog = sqlContext.sessionState.catalog
+ val tableIdent = TableIdentifier("tab1", Some("dbx"))
+ createDatabase(catalog, "dbx")
+ createTable(catalog, tableIdent)
+ assert(catalog.getTableMetadata(tableIdent).properties.isEmpty)
+ // set table properties
+ sql("ALTER TABLE dbx.tab1 SET TBLPROPERTIES ('andrew' = 'or14', 'kor' = 'bel')")
+ assert(catalog.getTableMetadata(tableIdent).properties ==
+ Map("andrew" -> "or14", "kor" -> "bel"))
+ // set table properties without explicitly specifying database
+ catalog.setCurrentDatabase("dbx")
+ sql("ALTER TABLE tab1 SET TBLPROPERTIES ('kor' = 'belle', 'kar' = 'bol')")
+ assert(catalog.getTableMetadata(tableIdent).properties ==
+ Map("andrew" -> "or14", "kor" -> "belle", "kar" -> "bol"))
+ // table to alter does not exist
+ intercept[AnalysisException] {
+ sql("ALTER TABLE does_not_exist SET TBLPROPERTIES ('winner' = 'loser')")
+ }
+ // throw exception for datasource tables
+ convertToDatasourceTable(catalog, tableIdent)
+ val e = intercept[AnalysisException] {
+ sql("ALTER TABLE tab1 SET TBLPROPERTIES ('sora' = 'bol')")
+ }
+ assert(e.getMessage.contains("datasource"))
+ }
+
+ test("alter table: unset properties") {
+ val catalog = sqlContext.sessionState.catalog
+ val tableIdent = TableIdentifier("tab1", Some("dbx"))
+ createDatabase(catalog, "dbx")
+ createTable(catalog, tableIdent)
+ // unset table properties
+ sql("ALTER TABLE dbx.tab1 SET TBLPROPERTIES ('j' = 'am', 'p' = 'an', 'c' = 'lan')")
+ sql("ALTER TABLE dbx.tab1 UNSET TBLPROPERTIES ('j')")
+ assert(catalog.getTableMetadata(tableIdent).properties == Map("p" -> "an", "c" -> "lan"))
+ // unset table properties without explicitly specifying database
+ catalog.setCurrentDatabase("dbx")
+ sql("ALTER TABLE tab1 UNSET TBLPROPERTIES ('p')")
+ assert(catalog.getTableMetadata(tableIdent).properties == Map("c" -> "lan"))
+ // table to alter does not exist
+ intercept[AnalysisException] {
+ sql("ALTER TABLE does_not_exist UNSET TBLPROPERTIES ('c' = 'lan')")
+ }
+ // property to unset does not exist
+ val e = intercept[AnalysisException] {
+ sql("ALTER TABLE tab1 UNSET TBLPROPERTIES ('c', 'xyz')")
+ }
+ assert(e.getMessage.contains("xyz"))
+ // property to unset does not exist, but "IF EXISTS" is specified
+ sql("ALTER TABLE tab1 UNSET TBLPROPERTIES IF EXISTS ('c', 'xyz')")
+ assert(catalog.getTableMetadata(tableIdent).properties.isEmpty)
+ // throw exception for datasource tables
+ convertToDatasourceTable(catalog, tableIdent)
+ val e1 = intercept[AnalysisException] {
+ sql("ALTER TABLE tab1 UNSET TBLPROPERTIES ('sora')")
+ }
+ assert(e1.getMessage.contains("datasource"))
+ }
+
+ test("alter table: set serde") {
+ testSetSerde(isDatasourceTable = false)
+ }
+
+ test("alter table: set serde (datasource table)") {
+ testSetSerde(isDatasourceTable = true)
+ }
+
+ test("alter table: bucketing is not supported") {
+ val catalog = sqlContext.sessionState.catalog
+ val tableIdent = TableIdentifier("tab1", Some("dbx"))
+ createDatabase(catalog, "dbx")
+ createTable(catalog, tableIdent)
+ assertUnsupported("ALTER TABLE dbx.tab1 CLUSTERED BY (blood, lemon, grape) INTO 11 BUCKETS")
+ assertUnsupported("ALTER TABLE dbx.tab1 CLUSTERED BY (fuji) SORTED BY (grape) INTO 5 BUCKETS")
+ assertUnsupported("ALTER TABLE dbx.tab1 NOT CLUSTERED")
+ assertUnsupported("ALTER TABLE dbx.tab1 NOT SORTED")
+ }
+
+ test("alter table: skew is not supported") {
+ val catalog = sqlContext.sessionState.catalog
+ val tableIdent = TableIdentifier("tab1", Some("dbx"))
+ createDatabase(catalog, "dbx")
+ createTable(catalog, tableIdent)
+ assertUnsupported("ALTER TABLE dbx.tab1 SKEWED BY (dt, country) ON " +
+ "(('2008-08-08', 'us'), ('2009-09-09', 'uk'), ('2010-10-10', 'cn'))")
+ assertUnsupported("ALTER TABLE dbx.tab1 SKEWED BY (dt, country) ON " +
+ "(('2008-08-08', 'us'), ('2009-09-09', 'uk')) STORED AS DIRECTORIES")
+ assertUnsupported("ALTER TABLE dbx.tab1 NOT SKEWED")
+ assertUnsupported("ALTER TABLE dbx.tab1 NOT STORED AS DIRECTORIES")
+ }
+
+ test("alter table: add partition") {
+ testAddPartitions(isDatasourceTable = false)
+ }
+
+ test("alter table: add partition (datasource table)") {
+ testAddPartitions(isDatasourceTable = true)
+ }
+
+ test("alter table: add partition is not supported for views") {
+ assertUnsupported("ALTER VIEW dbx.tab1 ADD IF NOT EXISTS PARTITION (b='2')")
+ }
+
+ test("alter table: drop partition") {
+ testDropPartitions(isDatasourceTable = false)
+ }
+
+ test("alter table: drop partition (datasource table)") {
+ testDropPartitions(isDatasourceTable = true)
+ }
+
+ test("alter table: drop partition is not supported for views") {
+ assertUnsupported("ALTER VIEW dbx.tab1 DROP IF EXISTS PARTITION (b='2')")
+ }
+
+ test("alter table: rename partition") {
+ val catalog = sqlContext.sessionState.catalog
+ val tableIdent = TableIdentifier("tab1", Some("dbx"))
+ val part1 = Map("a" -> "1")
+ val part2 = Map("b" -> "2")
+ val part3 = Map("c" -> "3")
+ createDatabase(catalog, "dbx")
+ createTable(catalog, tableIdent)
+ createTablePartition(catalog, part1, tableIdent)
+ createTablePartition(catalog, part2, tableIdent)
+ createTablePartition(catalog, part3, tableIdent)
+ assert(catalog.listPartitions(tableIdent).map(_.spec).toSet ==
+ Set(part1, part2, part3))
+ sql("ALTER TABLE dbx.tab1 PARTITION (a='1') RENAME TO PARTITION (a='100')")
+ sql("ALTER TABLE dbx.tab1 PARTITION (b='2') RENAME TO PARTITION (b='200')")
+ assert(catalog.listPartitions(tableIdent).map(_.spec).toSet ==
+ Set(Map("a" -> "100"), Map("b" -> "200"), part3))
+ // rename without explicitly specifying database
+ catalog.setCurrentDatabase("dbx")
+ sql("ALTER TABLE tab1 PARTITION (a='100') RENAME TO PARTITION (a='10')")
+ assert(catalog.listPartitions(tableIdent).map(_.spec).toSet ==
+ Set(Map("a" -> "10"), Map("b" -> "200"), part3))
+ // table to alter does not exist
+ intercept[AnalysisException] {
+ sql("ALTER TABLE does_not_exist PARTITION (c='3') RENAME TO PARTITION (c='333')")
+ }
+ // partition to rename does not exist
+ intercept[AnalysisException] {
+ sql("ALTER TABLE tab1 PARTITION (x='300') RENAME TO PARTITION (x='333')")
+ }
+ }
+
+ test("show tables") {
+ withTempTable("show1a", "show2b") {
+ sql(
+ """
+ |CREATE TEMPORARY TABLE show1a
+ |USING org.apache.spark.sql.sources.DDLScanSource
+ |OPTIONS (
+ | From '1',
+ | To '10',
+ | Table 'test1'
+ |)
+ """.stripMargin)
+ sql(
+ """
+ |CREATE TEMPORARY TABLE show2b
+ |USING org.apache.spark.sql.sources.DDLScanSource
+ |OPTIONS (
+ | From '1',
+ | To '10',
+ | Table 'test1'
+ |)
+ """.stripMargin)
+ checkAnswer(
+ sql("SHOW TABLES IN default 'show1*'"),
+ Row("show1a", true) :: Nil)
+
+ checkAnswer(
+ sql("SHOW TABLES IN default 'show1*|show2*'"),
+ Row("show1a", true) ::
+ Row("show2b", true) :: Nil)
+
+ checkAnswer(
+ sql("SHOW TABLES 'show1*|show2*'"),
+ Row("show1a", true) ::
+ Row("show2b", true) :: Nil)
+
+ assert(
+ sql("SHOW TABLES").count() >= 2)
+ assert(
+ sql("SHOW TABLES IN default").count() >= 2)
+ }
+ }
+
+ test("show databases") {
+ sql("CREATE DATABASE showdb1A")
+ sql("CREATE DATABASE showdb2B")
+
+ assert(
+ sql("SHOW DATABASES").count() >= 2)
+
+ checkAnswer(
+ sql("SHOW DATABASES LIKE '*db1A'"),
+ Row("showdb1A") :: Nil)
+
+ checkAnswer(
+ sql("SHOW DATABASES LIKE 'showdb1A'"),
+ Row("showdb1A") :: Nil)
+
+ checkAnswer(
+ sql("SHOW DATABASES LIKE '*db1A|*db2B'"),
+ Row("showdb1A") ::
+ Row("showdb2B") :: Nil)
+
+ checkAnswer(
+ sql("SHOW DATABASES LIKE 'non-existentdb'"),
+ Nil)
+ }
+
+ test("drop table - temporary table") {
+ val catalog = sqlContext.sessionState.catalog
+ sql(
+ """
+ |CREATE TEMPORARY TABLE tab1
+ |USING org.apache.spark.sql.sources.DDLScanSource
+ |OPTIONS (
+ | From '1',
+ | To '10',
+ | Table 'test1'
+ |)
+ """.stripMargin)
+ assert(catalog.listTables("default") == Seq(TableIdentifier("tab1")))
+ sql("DROP TABLE tab1")
+ assert(catalog.listTables("default") == Nil)
+ }
+
+ test("drop table") {
+ testDropTable(isDatasourceTable = false)
+ }
+
+ test("drop table - data source table") {
+ testDropTable(isDatasourceTable = true)
+ }
+
+ private def testDropTable(isDatasourceTable: Boolean): Unit = {
+ val catalog = sqlContext.sessionState.catalog
+ val tableIdent = TableIdentifier("tab1", Some("dbx"))
+ createDatabase(catalog, "dbx")
+ createTable(catalog, tableIdent)
+ if (isDatasourceTable) {
+ convertToDatasourceTable(catalog, tableIdent)
+ }
+ assert(catalog.listTables("dbx") == Seq(tableIdent))
+ sql("DROP TABLE dbx.tab1")
+ assert(catalog.listTables("dbx") == Nil)
+ sql("DROP TABLE IF EXISTS dbx.tab1")
+ // dropping a table that no longer exists should not throw an exception
+ sql("DROP TABLE dbx.tab1")
+ }
+
+ test("drop view in SQLContext") {
+ // SQLContext does not support CREATE VIEW. An error message is logged if tab1 does not exist.
+ sql("DROP VIEW tab1")
+
+ val catalog = sqlContext.sessionState.catalog
+ val tableIdent = TableIdentifier("tab1", Some("dbx"))
+ createDatabase(catalog, "dbx")
+ createTable(catalog, tableIdent)
+ assert(catalog.listTables("dbx") == Seq(tableIdent))
+
+ val e = intercept[AnalysisException] {
+ sql("DROP VIEW dbx.tab1")
+ }
+ assert(
+ e.getMessage.contains("Cannot drop a table with DROP VIEW. Please use DROP TABLE instead"))
+ }
+
+ private def convertToDatasourceTable(
+ catalog: SessionCatalog,
+ tableIdent: TableIdentifier): Unit = {
+ catalog.alterTable(catalog.getTableMetadata(tableIdent).copy(
+ properties = Map("spark.sql.sources.provider" -> "csv")))
+ }
+
+ private def testSetLocation(isDatasourceTable: Boolean): Unit = {
+ val catalog = sqlContext.sessionState.catalog
+ val tableIdent = TableIdentifier("tab1", Some("dbx"))
+ val partSpec = Map("a" -> "1")
+ createDatabase(catalog, "dbx")
+ createTable(catalog, tableIdent)
+ createTablePartition(catalog, partSpec, tableIdent)
+ if (isDatasourceTable) {
+ convertToDatasourceTable(catalog, tableIdent)
+ }
+ assert(catalog.getTableMetadata(tableIdent).storage.locationUri.isEmpty)
+ assert(catalog.getTableMetadata(tableIdent).storage.serdeProperties.isEmpty)
+ assert(catalog.getPartition(tableIdent, partSpec).storage.locationUri.isEmpty)
+ assert(catalog.getPartition(tableIdent, partSpec).storage.serdeProperties.isEmpty)
+ // Verify that the location is set to the expected string
+ def verifyLocation(expected: String, spec: Option[TablePartitionSpec] = None): Unit = {
+ val storageFormat = spec
+ .map { s => catalog.getPartition(tableIdent, s).storage }
+ .getOrElse { catalog.getTableMetadata(tableIdent).storage }
+ if (isDatasourceTable) {
+ if (spec.isDefined) {
+ assert(storageFormat.serdeProperties.isEmpty)
+ assert(storageFormat.locationUri.isEmpty)
+ } else {
+ assert(storageFormat.serdeProperties.get("path") === Some(expected))
+ assert(storageFormat.locationUri === Some(expected))
+ }
+ } else {
+ assert(storageFormat.locationUri === Some(expected))
+ }
+ }
+ // set table location
+ sql("ALTER TABLE dbx.tab1 SET LOCATION '/path/to/your/lovely/heart'")
+ verifyLocation("/path/to/your/lovely/heart")
+ // set table partition location
+ maybeWrapException(isDatasourceTable) {
+ sql("ALTER TABLE dbx.tab1 PARTITION (a='1') SET LOCATION '/path/to/part/ways'")
+ }
+ verifyLocation("/path/to/part/ways", Some(partSpec))
+ // set table location without explicitly specifying database
+ catalog.setCurrentDatabase("dbx")
+ sql("ALTER TABLE tab1 SET LOCATION '/swanky/steak/place'")
+ verifyLocation("/swanky/steak/place")
+ // set table partition location without explicitly specifying database
+ maybeWrapException(isDatasourceTable) {
+ sql("ALTER TABLE tab1 PARTITION (a='1') SET LOCATION 'vienna'")
+ }
+ verifyLocation("vienna", Some(partSpec))
+ // table to alter does not exist
+ intercept[AnalysisException] {
+ sql("ALTER TABLE dbx.does_not_exist SET LOCATION '/mister/spark'")
+ }
+ // partition to alter does not exist
+ intercept[AnalysisException] {
+ sql("ALTER TABLE dbx.tab1 PARTITION (b='2') SET LOCATION '/mister/spark'")
+ }
+ }
+
+ private def testSetSerde(isDatasourceTable: Boolean): Unit = {
+ val catalog = sqlContext.sessionState.catalog
+ val tableIdent = TableIdentifier("tab1", Some("dbx"))
+ createDatabase(catalog, "dbx")
+ createTable(catalog, tableIdent)
+ if (isDatasourceTable) {
+ convertToDatasourceTable(catalog, tableIdent)
+ }
+ assert(catalog.getTableMetadata(tableIdent).storage.serde.isEmpty)
+ assert(catalog.getTableMetadata(tableIdent).storage.serdeProperties.isEmpty)
+ // set table serde and/or properties (should fail on datasource tables)
+ if (isDatasourceTable) {
+ val e1 = intercept[AnalysisException] {
+ sql("ALTER TABLE dbx.tab1 SET SERDE 'whatever'")
+ }
+ val e2 = intercept[AnalysisException] {
+ sql("ALTER TABLE dbx.tab1 SET SERDE 'org.apache.madoop' " +
+ "WITH SERDEPROPERTIES ('k' = 'v', 'kay' = 'vee')")
+ }
+ assert(e1.getMessage.contains("datasource"))
+ assert(e2.getMessage.contains("datasource"))
+ } else {
+ sql("ALTER TABLE dbx.tab1 SET SERDE 'org.apache.jadoop'")
+ assert(catalog.getTableMetadata(tableIdent).storage.serde == Some("org.apache.jadoop"))
+ assert(catalog.getTableMetadata(tableIdent).storage.serdeProperties.isEmpty)
+ sql("ALTER TABLE dbx.tab1 SET SERDE 'org.apache.madoop' " +
+ "WITH SERDEPROPERTIES ('k' = 'v', 'kay' = 'vee')")
+ assert(catalog.getTableMetadata(tableIdent).storage.serde == Some("org.apache.madoop"))
+ assert(catalog.getTableMetadata(tableIdent).storage.serdeProperties ==
+ Map("k" -> "v", "kay" -> "vee"))
+ }
+ // set serde properties only
+ sql("ALTER TABLE dbx.tab1 SET SERDEPROPERTIES ('k' = 'vvv', 'kay' = 'vee')")
+ assert(catalog.getTableMetadata(tableIdent).storage.serdeProperties ==
+ Map("k" -> "vvv", "kay" -> "vee"))
+ // set things without explicitly specifying database
+ catalog.setCurrentDatabase("dbx")
+ sql("ALTER TABLE tab1 SET SERDEPROPERTIES ('kay' = 'veee')")
+ assert(catalog.getTableMetadata(tableIdent).storage.serdeProperties ==
+ Map("k" -> "vvv", "kay" -> "veee"))
+ // table to alter does not exist
+ intercept[AnalysisException] {
+ sql("ALTER TABLE does_not_exist SET SERDEPROPERTIES ('x' = 'y')")
+ }
+ }
+
+ private def testAddPartitions(isDatasourceTable: Boolean): Unit = {
+ val catalog = sqlContext.sessionState.catalog
+ val tableIdent = TableIdentifier("tab1", Some("dbx"))
+ val part1 = Map("a" -> "1")
+ val part2 = Map("b" -> "2")
+ val part3 = Map("c" -> "3")
+ val part4 = Map("d" -> "4")
+ createDatabase(catalog, "dbx")
+ createTable(catalog, tableIdent)
+ createTablePartition(catalog, part1, tableIdent)
+ if (isDatasourceTable) {
+ convertToDatasourceTable(catalog, tableIdent)
+ }
+ assert(catalog.listPartitions(tableIdent).map(_.spec).toSet == Set(part1))
+ maybeWrapException(isDatasourceTable) {
+ sql("ALTER TABLE dbx.tab1 ADD IF NOT EXISTS " +
+ "PARTITION (b='2') LOCATION 'paris' PARTITION (c='3')")
+ }
+ if (!isDatasourceTable) {
+ assert(catalog.listPartitions(tableIdent).map(_.spec).toSet == Set(part1, part2, part3))
+ assert(catalog.getPartition(tableIdent, part1).storage.locationUri.isEmpty)
+ assert(catalog.getPartition(tableIdent, part2).storage.locationUri == Some("paris"))
+ assert(catalog.getPartition(tableIdent, part3).storage.locationUri.isEmpty)
+ }
+ // add partitions without explicitly specifying database
+ catalog.setCurrentDatabase("dbx")
+ maybeWrapException(isDatasourceTable) {
+ sql("ALTER TABLE tab1 ADD IF NOT EXISTS PARTITION (d='4')")
+ }
+ if (!isDatasourceTable) {
+ assert(catalog.listPartitions(tableIdent).map(_.spec).toSet ==
+ Set(part1, part2, part3, part4))
+ }
+ // table to alter does not exist
+ intercept[AnalysisException] {
+ sql("ALTER TABLE does_not_exist ADD IF NOT EXISTS PARTITION (d='4')")
+ }
+ // partition to add already exists
+ intercept[AnalysisException] {
+ sql("ALTER TABLE tab1 ADD PARTITION (d='4')")
+ }
+ maybeWrapException(isDatasourceTable) {
+ sql("ALTER TABLE tab1 ADD IF NOT EXISTS PARTITION (d='4')")
+ }
+ if (!isDatasourceTable) {
+ assert(catalog.listPartitions(tableIdent).map(_.spec).toSet ==
+ Set(part1, part2, part3, part4))
+ }
+ }
+
+ private def testDropPartitions(isDatasourceTable: Boolean): Unit = {
+ val catalog = sqlContext.sessionState.catalog
+ val tableIdent = TableIdentifier("tab1", Some("dbx"))
+ val part1 = Map("a" -> "1")
+ val part2 = Map("b" -> "2")
+ val part3 = Map("c" -> "3")
+ val part4 = Map("d" -> "4")
+ createDatabase(catalog, "dbx")
+ createTable(catalog, tableIdent)
+ createTablePartition(catalog, part1, tableIdent)
+ createTablePartition(catalog, part2, tableIdent)
+ createTablePartition(catalog, part3, tableIdent)
+ createTablePartition(catalog, part4, tableIdent)
+ assert(catalog.listPartitions(tableIdent).map(_.spec).toSet ==
+ Set(part1, part2, part3, part4))
+ if (isDatasourceTable) {
+ convertToDatasourceTable(catalog, tableIdent)
+ }
+ maybeWrapException(isDatasourceTable) {
+ sql("ALTER TABLE dbx.tab1 DROP IF EXISTS PARTITION (d='4'), PARTITION (c='3')")
+ }
+ if (!isDatasourceTable) {
+ assert(catalog.listPartitions(tableIdent).map(_.spec).toSet == Set(part1, part2))
+ }
+ // drop partitions without explicitly specifying database
+ catalog.setCurrentDatabase("dbx")
+ maybeWrapException(isDatasourceTable) {
+ sql("ALTER TABLE tab1 DROP IF EXISTS PARTITION (b='2')")
+ }
+ if (!isDatasourceTable) {
+ assert(catalog.listPartitions(tableIdent).map(_.spec) == Seq(part1))
+ }
+ // table to alter does not exist
+ intercept[AnalysisException] {
+ sql("ALTER TABLE does_not_exist DROP IF EXISTS PARTITION (b='2')")
+ }
+ // partition to drop does not exist
+ intercept[AnalysisException] {
+ sql("ALTER TABLE tab1 DROP PARTITION (x='300')")
+ }
+ maybeWrapException(isDatasourceTable) {
+ sql("ALTER TABLE tab1 DROP IF EXISTS PARTITION (x='300')")
+ }
+ if (!isDatasourceTable) {
+ assert(catalog.listPartitions(tableIdent).map(_.spec) == Seq(part1))
+ }
+ }
+
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala
index 1fa15730bc..dac56d3936 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala
@@ -22,8 +22,6 @@ import java.io.File
import org.apache.hadoop.fs.FileStatus
import org.apache.hadoop.mapreduce.Job
-import org.apache.spark.broadcast.Broadcast
-import org.apache.spark.rdd.RDD
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionSet, PredicateHelper}
@@ -34,8 +32,7 @@ import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.sources._
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types.{IntegerType, StructType}
-import org.apache.spark.util.{SerializableConfiguration, Utils}
-import org.apache.spark.util.collection.BitSet
+import org.apache.spark.util.Utils
class FileSourceStrategySuite extends QueryTest with SharedSQLContext with PredicateHelper {
import testImplicits._
@@ -76,7 +73,8 @@ class FileSourceStrategySuite extends QueryTest with SharedSQLContext with Predi
"file2" -> 5,
"file3" -> 5))
- withSQLConf(SQLConf.FILES_MAX_PARTITION_BYTES.key -> "10") {
+ withSQLConf(SQLConf.FILES_MAX_PARTITION_BYTES.key -> "11",
+ SQLConf.FILES_OPEN_COST_IN_BYTES.key -> "1") {
checkScan(table.select('c1)) { partitions =>
// 5 byte files should be laid out [(5, 5), (5)]
assert(partitions.size == 2, "when checking partitions")
@@ -98,11 +96,12 @@ class FileSourceStrategySuite extends QueryTest with SharedSQLContext with Predi
createTable(
files = Seq(
"file1" -> 15,
- "file2" -> 4))
+ "file2" -> 3))
- withSQLConf(SQLConf.FILES_MAX_PARTITION_BYTES.key -> "10") {
+ withSQLConf(SQLConf.FILES_MAX_PARTITION_BYTES.key -> "10",
+ SQLConf.FILES_OPEN_COST_IN_BYTES.key -> "1") {
checkScan(table.select('c1)) { partitions =>
- // Files should be laid out [(0-5), (5-10, 4)]
+ // Files should be laid out [(0-10), (10-15, 4)]
assert(partitions.size == 2, "when checking partitions")
assert(partitions(0).files.size == 1, "when checking partition 1")
assert(partitions(1).files.size == 2, "when checking partition 2")
@@ -121,6 +120,53 @@ class FileSourceStrategySuite extends QueryTest with SharedSQLContext with Predi
}
}
+ test("Unpartitioned table, many files that get split") {
+ val table =
+ createTable(
+ files = Seq(
+ "file1" -> 2,
+ "file2" -> 2,
+ "file3" -> 1,
+ "file4" -> 1,
+ "file5" -> 1,
+ "file6" -> 1))
+
+ withSQLConf(SQLConf.FILES_MAX_PARTITION_BYTES.key -> "4",
+ SQLConf.FILES_OPEN_COST_IN_BYTES.key -> "1") {
+ checkScan(table.select('c1)) { partitions =>
+ // Files should be laid out [(file1), (file2, file3), (file4, file5), (file6)]
+ assert(partitions.size == 4, "when checking partitions")
+ assert(partitions(0).files.size == 1, "when checking partition 1")
+ assert(partitions(1).files.size == 2, "when checking partition 2")
+ assert(partitions(2).files.size == 2, "when checking partition 3")
+ assert(partitions(3).files.size == 1, "when checking partition 4")
+
+ // First partition reads (file1)
+ assert(partitions(0).files(0).start == 0)
+ assert(partitions(0).files(0).length == 2)
+
+ // Second partition reads (file2, file3)
+ assert(partitions(1).files(0).start == 0)
+ assert(partitions(1).files(0).length == 2)
+ assert(partitions(1).files(1).start == 0)
+ assert(partitions(1).files(1).length == 1)
+
+ // Third partition reads (file4, file5)
+ assert(partitions(2).files(0).start == 0)
+ assert(partitions(2).files(0).length == 1)
+ assert(partitions(2).files(1).start == 0)
+ assert(partitions(2).files(1).length == 1)
+
+ // Final partition reads (file6)
+ assert(partitions(3).files(0).start == 0)
+ assert(partitions(3).files(0).length == 1)
+ }
+
+ checkPartitionSchema(StructType(Nil))
+ checkDataSchema(StructType(Nil).add("c1", IntegerType))
+ }
+ }
+
test("partitioned table") {
val table =
createTable(
@@ -147,6 +193,34 @@ class FileSourceStrategySuite extends QueryTest with SharedSQLContext with Predi
checkDataFilters(Set(IsNotNull("c1"), EqualTo("c1", 1)))
}
+ test("partitioned table - case insensitive") {
+ withSQLConf("spark.sql.caseSensitive" -> "false") {
+ val table =
+ createTable(
+ files = Seq(
+ "p1=1/file1" -> 10,
+ "p1=2/file2" -> 10))
+
+ // Only one file should be read.
+ checkScan(table.where("P1 = 1")) { partitions =>
+ assert(partitions.size == 1, "when checking partitions")
+ assert(partitions.head.files.size == 1, "when files in partition 1")
+ }
+ // We don't need to reevaluate filters that are only on partitions.
+ checkDataFilters(Set.empty)
+
+ // Only one file should be read.
+ checkScan(table.where("P1 = 1 AND C1 = 1 AND (P1 + C1) = 1")) { partitions =>
+ assert(partitions.size == 1, "when checking partitions")
+ assert(partitions.head.files.size == 1, "when checking files in partition 1")
+ assert(partitions.head.files.head.partitionValues.getInt(0) == 1,
+ "when checking partition values")
+ }
+ // Only the filters that do not contain the partition column should be pushed down
+ checkDataFilters(Set(IsNotNull("c1"), EqualTo("c1", 1)))
+ }
+ }
+
test("partitioned table - after scan filters") {
val table =
createTable(
@@ -230,7 +304,8 @@ class FileSourceStrategySuite extends QueryTest with SharedSQLContext with Predi
/** Plans the query and calls the provided validation function with the planned partitioning. */
def checkScan(df: DataFrame)(func: Seq[FilePartition] => Unit): Unit = {
val fileScan = df.queryExecution.executedPlan.collect {
- case DataSourceScan(_, scan: FileScanRDD, _, _) => scan
+ case scan: DataSourceScan if scan.rdd.isInstanceOf[FileScanRDD] =>
+ scan.rdd.asInstanceOf[FileScanRDD]
}.headOption.getOrElse {
fail(s"No FileScan in query\n${df.queryExecution}")
}
@@ -315,28 +390,17 @@ class TestFileFormat extends FileFormat {
throw new NotImplementedError("JUST FOR TESTING")
}
- override def buildInternalScan(
- sqlContext: SQLContext,
- dataSchema: StructType,
- requiredColumns: Array[String],
- filters: Array[Filter],
- bucketSet: Option[BitSet],
- inputFiles: Seq[FileStatus],
- broadcastedConf: Broadcast[SerializableConfiguration],
- options: Map[String, String]): RDD[InternalRow] = {
- throw new NotImplementedError("JUST FOR TESTING")
- }
-
override def buildReader(
sqlContext: SQLContext,
- partitionSchema: StructType,
dataSchema: StructType,
+ partitionSchema: StructType,
+ requiredSchema: StructType,
filters: Seq[Filter],
options: Map[String, String]): PartitionedFile => Iterator[InternalRow] = {
// Record the arguments so they can be checked in the test case.
LastArguments.partitionSchema = partitionSchema
- LastArguments.dataSchema = dataSchema
+ LastArguments.dataSchema = requiredSchema
LastArguments.filters = filters
LastArguments.options = options
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala
index 3a7cb25b4f..23d422635b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.datasources.csv
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.types._
-class InferSchemaSuite extends SparkFunSuite {
+class CSVInferSchemaSuite extends SparkFunSuite {
test("String fields types are inferred correctly from null types") {
assert(CSVInferSchema.inferField(NullType, "") == NullType)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVParserSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVParserSuite.scala
index c0c38c6787..aaeecef5f3 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVParserSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVParserSuite.scala
@@ -20,8 +20,8 @@ package org.apache.spark.sql.execution.datasources.csv
import org.apache.spark.SparkFunSuite
/**
- * test cases for StringIteratorReader
- */
+ * test cases for StringIteratorReader
+ */
class CSVParserSuite extends SparkFunSuite {
private def readAll(iter: Iterator[String]) = {
@@ -46,7 +46,7 @@ class CSVParserSuite extends SparkFunSuite {
var numRead = 0
var n = 0
do { // try to fill cbuf
- var off = 0
+ var off = 0
var len = cbuf.length
n = reader.read(cbuf, off, len)
@@ -81,7 +81,7 @@ class CSVParserSuite extends SparkFunSuite {
test("Regular case") {
val input = List("This is a string", "This is another string", "Small", "", "\"quoted\"")
val read = readAll(input.toIterator)
- assert(read === input.mkString("\n") ++ ("\n"))
+ assert(read === input.mkString("\n") ++ "\n")
}
test("Empty iter") {
@@ -93,12 +93,12 @@ class CSVParserSuite extends SparkFunSuite {
test("Embedded new line") {
val input = List("This is a string", "This is another string", "Small\n", "", "\"quoted\"")
val read = readAll(input.toIterator)
- assert(read === input.mkString("\n") ++ ("\n"))
+ assert(read === input.mkString("\n") ++ "\n")
}
test("Buffer Regular case") {
val input = List("This is a string", "This is another string", "Small", "", "\"quoted\"")
- val output = input.mkString("\n") ++ ("\n")
+ val output = input.mkString("\n") ++ "\n"
for(i <- 1 to output.length + 5) {
val read = readBufAll(input.toIterator, i)
assert(read === output)
@@ -116,7 +116,7 @@ class CSVParserSuite extends SparkFunSuite {
test("Buffer Embedded new line") {
val input = List("This is a string", "This is another string", "Small\n", "", "\"quoted\"")
- val output = input.mkString("\n") ++ ("\n")
+ val output = input.mkString("\n") ++ "\n"
for(i <- 1 to output.length + 5) {
val read = readBufAll(input.toIterator, 1)
assert(read === output)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
index 58d9d69d9a..9baae80f15 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
@@ -45,6 +45,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
private val disableCommentsFile = "disable_comments.csv"
private val boolFile = "bool.csv"
private val simpleSparseFile = "simple_sparse.csv"
+ private val unescapedQuotesFile = "unescaped-quotes.csv"
private def testFile(fileName: String): String = {
Thread.currentThread().getContextClassLoader.getResource(fileName).toString
@@ -140,6 +141,17 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
verifyCars(cars, withHeader = true)
}
+ test("parse unescaped quotes with maxCharsPerColumn") {
+ val rows = sqlContext.read
+ .format("csv")
+ .option("maxCharsPerColumn", "4")
+ .load(testFile(unescapedQuotesFile))
+
+ val expectedRows = Seq(Row("\"a\"b", "ccc", "ddd"), Row("ab", "cc\"c", "ddd\""))
+
+ checkAnswer(rows, expectedRows)
+ }
+
test("bad encoding name") {
val exception = intercept[UnsupportedCharsetException] {
sqlContext
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
index c108d81b18..e17340c70b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
@@ -745,8 +745,8 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
)
}
- test("Loading a JSON dataset floatAsBigDecimal returns schema with float types as BigDecimal") {
- val jsonDF = sqlContext.read.option("floatAsBigDecimal", "true").json(primitiveFieldAndType)
+ test("Loading a JSON dataset prefersDecimal returns schema with float types as BigDecimal") {
+ val jsonDF = sqlContext.read.option("prefersDecimal", "true").json(primitiveFieldAndType)
val expectedSchema = StructType(
StructField("bigInteger", DecimalType(20, 0), true) ::
@@ -773,6 +773,72 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
)
}
+ test("Find compatible types even if inferred DecimalType is not capable of other IntegralType") {
+ val mixedIntegerAndDoubleRecords = sparkContext.parallelize(
+ """{"a": 3, "b": 1.1}""" ::
+ s"""{"a": 3.1, "b": 0.${"0" * 38}1}""" :: Nil)
+ val jsonDF = sqlContext.read
+ .option("prefersDecimal", "true")
+ .json(mixedIntegerAndDoubleRecords)
+
+ // The values in the `a` field will be decimals because they fit in a decimal. The values in the
+ // `b` field will be doubles because `1.0E-39D` does not fit.
+ val expectedSchema = StructType(
+ StructField("a", DecimalType(21, 1), true) ::
+ StructField("b", DoubleType, true) :: Nil)
+
+ assert(expectedSchema === jsonDF.schema)
+ checkAnswer(
+ jsonDF,
+ Row(BigDecimal("3"), 1.1D) ::
+ Row(BigDecimal("3.1"), 1.0E-39D) :: Nil
+ )
+ }
+
+ test("Infer big integers correctly even when it does not fit in decimal") {
+ val jsonDF = sqlContext.read
+ .json(bigIntegerRecords)
+
+ // The value in the `a` field will be a double because it does not fit in a decimal. The value in
+ // the `b` field, `92233720368547758070`, fits and will be a decimal.
+ val expectedSchema = StructType(
+ StructField("a", DoubleType, true) ::
+ StructField("b", DecimalType(20, 0), true) :: Nil)
+
+ assert(expectedSchema === jsonDF.schema)
+ checkAnswer(jsonDF, Row(1.0E38D, BigDecimal("92233720368547758070")))
+ }
+
+ test("Infer floating-point values correctly even when it does not fit in decimal") {
+ val jsonDF = sqlContext.read
+ .option("prefersDecimal", "true")
+ .json(floatingValueRecords)
+
+ // The value in the `a` field will be a double because it does not fit in a decimal. The value in
+ // the `b` field will be a decimal because `0.01` has a precision equal to its scale.
+ val expectedSchema = StructType(
+ StructField("a", DoubleType, true) ::
+ StructField("b", DecimalType(2, 2), true):: Nil)
+
+ assert(expectedSchema === jsonDF.schema)
+ checkAnswer(jsonDF, Row(1.0E-39D, BigDecimal(0.01)))
+
+ val mergedJsonDF = sqlContext.read
+ .option("prefersDecimal", "true")
+ .json(floatingValueRecords ++ bigIntegerRecords)
+
+ val expectedMergedSchema = StructType(
+ StructField("a", DoubleType, true) ::
+ StructField("b", DecimalType(22, 2), true):: Nil)
+
+ assert(expectedMergedSchema === mergedJsonDF.schema)
+ checkAnswer(
+ mergedJsonDF,
+ Row(1.0E-39D, BigDecimal(0.01)) ::
+ Row(1.0E38D, BigDecimal("92233720368547758070")) :: Nil
+ )
+ }
+
test("Loading a JSON dataset from a text file with SQL") {
val dir = Utils.createTempDir()
dir.delete()
@@ -1598,4 +1664,19 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
)
}
}
+
+ test("wide nested json table") {
+ val nested = (1 to 100).map { i =>
+ s"""
+ |"c$i": $i
+ """.stripMargin
+ }.mkString(", ")
+ val json = s"""
+ |{"a": [{$nested}], "b": [{$nested}]}
+ """.stripMargin
+ val rdd = sqlContext.sparkContext.makeRDD(Seq(json))
+ val df = sqlContext.read.json(rdd)
+ assert(df.schema.size === 2)
+ df.collect()
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/TestJsonData.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/TestJsonData.scala
index b2eff816ee..2873c6a881 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/TestJsonData.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/TestJsonData.scala
@@ -214,6 +214,14 @@ private[json] trait TestJsonData {
"""{"a": {"b": 1}}""" ::
"""{"a": []}""" :: Nil)
+ def floatingValueRecords: RDD[String] =
+ sqlContext.sparkContext.parallelize(
+ s"""{"a": 0.${"0" * 38}1, "b": 0.01}""" :: Nil)
+
+ def bigIntegerRecords: RDD[String] =
+ sqlContext.sparkContext.parallelize(
+ s"""{"a": 1${"0" * 38}, "b": 92233720368547758070}""" :: Nil)
+
lazy val singleRow: RDD[String] = sqlContext.sparkContext.parallelize("""{"a":123}""" :: Nil)
def empty: RDD[String] = sqlContext.sparkContext.parallelize(Seq[String]())
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala
index 9746187d22..581095d3dc 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala
@@ -445,55 +445,6 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext {
}
}
- testQuietly("SPARK-6352 DirectParquetOutputCommitter") {
- val clonedConf = new Configuration(hadoopConfiguration)
-
- // Write to a parquet file and let it fail.
- // _temporary should be missing if direct output committer works.
- try {
- hadoopConfiguration.set("spark.sql.parquet.output.committer.class",
- classOf[DirectParquetOutputCommitter].getCanonicalName)
- sqlContext.udf.register("div0", (x: Int) => x / 0)
- withTempPath { dir =>
- intercept[org.apache.spark.SparkException] {
- sqlContext.sql("select div0(1) as div0").write.parquet(dir.getCanonicalPath)
- }
- val path = new Path(dir.getCanonicalPath, "_temporary")
- val fs = path.getFileSystem(hadoopConfiguration)
- assert(!fs.exists(path))
- }
- } finally {
- // Hadoop 1 doesn't have `Configuration.unset`
- hadoopConfiguration.clear()
- clonedConf.asScala.foreach(entry => hadoopConfiguration.set(entry.getKey, entry.getValue))
- }
- }
-
- testQuietly("SPARK-9849 DirectParquetOutputCommitter qualified name backwards compatiblity") {
- val clonedConf = new Configuration(hadoopConfiguration)
-
- // Write to a parquet file and let it fail.
- // _temporary should be missing if direct output committer works.
- try {
- hadoopConfiguration.set("spark.sql.parquet.output.committer.class",
- "org.apache.spark.sql.parquet.DirectParquetOutputCommitter")
- sqlContext.udf.register("div0", (x: Int) => x / 0)
- withTempPath { dir =>
- intercept[org.apache.spark.SparkException] {
- sqlContext.sql("select div0(1) as div0").write.parquet(dir.getCanonicalPath)
- }
- val path = new Path(dir.getCanonicalPath, "_temporary")
- val fs = path.getFileSystem(hadoopConfiguration)
- assert(!fs.exists(path))
- }
- } finally {
- // Hadoop 1 doesn't have `Configuration.unset`
- hadoopConfiguration.clear()
- clonedConf.asScala.foreach(entry => hadoopConfiguration.set(entry.getKey, entry.getValue))
- }
- }
-
-
test("SPARK-8121: spark.sql.parquet.output.committer.class shouldn't be overridden") {
withTempPath { dir =>
val clonedConf = new Configuration(hadoopConfiguration)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala
index 2f806ebba6..7d206e7bc4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala
@@ -579,6 +579,16 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext
assert(CatalystReadSupport.expandUDT(schema) === expected)
}
+
+ test("read/write wide table") {
+ withTempPath { dir =>
+ val path = dir.getCanonicalPath
+
+ val df = sqlContext.range(1000).select(Seq.tabulate(1000) {i => ('id + i).as(s"c$i")} : _*)
+ df.write.mode(SaveMode.Overwrite).parquet(path)
+ checkAnswer(sqlContext.read.parquet(path), df)
+ }
+ }
}
object TestingUDT {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala
index 22189477d2..8aa0114d98 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala
@@ -19,10 +19,23 @@ package org.apache.spark.sql.execution.debug
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.test.SharedSQLContext
+import org.apache.spark.sql.test.SQLTestData.TestData
class DebuggingSuite extends SparkFunSuite with SharedSQLContext {
test("DataFrame.debug()") {
testData.debug()
}
+
+ test("Dataset.debug()") {
+ import testImplicits._
+ testData.as[TestData].debug()
+ }
+
+ test("debugCodegen") {
+ val res = codegenString(sqlContext.range(10).groupBy("id").count().queryExecution.executedPlan)
+ assert(res.contains("Subtree 1 / 2"))
+ assert(res.contains("Subtree 2 / 2"))
+ assert(res.contains("Object[]"))
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/ExistenceJoinSuite.scala
index 985a96f684..8cdfa8afd0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/ExistenceJoinSuite.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.joins
import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.catalyst.expressions.{And, Expression, LessThan}
import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys
-import org.apache.spark.sql.catalyst.plans.{Inner, LeftSemi}
+import org.apache.spark.sql.catalyst.plans.{Inner, JoinType, LeftAnti, LeftSemi}
import org.apache.spark.sql.catalyst.plans.logical.Join
import org.apache.spark.sql.execution.{SparkPlan, SparkPlanTest}
import org.apache.spark.sql.execution.exchange.EnsureRequirements
@@ -28,7 +28,7 @@ import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types.{DoubleType, IntegerType, StructType}
-class SemiJoinSuite extends SparkPlanTest with SharedSQLContext {
+class ExistenceJoinSuite extends SparkPlanTest with SharedSQLContext {
private lazy val left = sqlContext.createDataFrame(
sparkContext.parallelize(Seq(
@@ -58,14 +58,20 @@ class SemiJoinSuite extends SparkPlanTest with SharedSQLContext {
LessThan(left.col("b").expr, right.col("d").expr))
}
+ private lazy val conditionNEQ = {
+ And((left.col("a") < right.col("c")).expr,
+ LessThan(left.col("b").expr, right.col("d").expr))
+ }
+
// Note: the input dataframes and expression must be evaluated lazily because
// the SQLContext should be used only within a test to keep SQL tests stable
- private def testLeftSemiJoin(
+ private def testExistenceJoin(
testName: String,
+ joinType: JoinType,
leftRows: => DataFrame,
rightRows: => DataFrame,
condition: => Expression,
- expectedAnswer: Seq[Product]): Unit = {
+ expectedAnswer: Seq[Row]): Unit = {
def extractJoinParts(): Option[ExtractEquiJoinKeys.ReturnType] = {
val join = Join(leftRows.logicalPlan, rightRows.logicalPlan, Inner, Some(condition))
@@ -73,25 +79,26 @@ class SemiJoinSuite extends SparkPlanTest with SharedSQLContext {
}
test(s"$testName using ShuffledHashJoin") {
- extractJoinParts().foreach { case (joinType, leftKeys, rightKeys, boundCondition, _, _) =>
+ extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) =>
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
EnsureRequirements(left.sqlContext.sessionState.conf).apply(
ShuffledHashJoin(
- leftKeys, rightKeys, LeftSemi, BuildRight, boundCondition, left, right)),
- expectedAnswer.map(Row.fromTuple),
+ leftKeys, rightKeys, joinType, BuildRight, boundCondition, left, right)),
+ expectedAnswer,
sortAnswers = true)
}
}
}
test(s"$testName using BroadcastHashJoin") {
- extractJoinParts().foreach { case (joinType, leftKeys, rightKeys, boundCondition, _, _) =>
+ extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) =>
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
- BroadcastHashJoin(
- leftKeys, rightKeys, LeftSemi, BuildRight, boundCondition, left, right),
- expectedAnswer.map(Row.fromTuple),
+ EnsureRequirements(left.sqlContext.sessionState.conf).apply(
+ BroadcastHashJoin(
+ leftKeys, rightKeys, joinType, BuildRight, boundCondition, left, right)),
+ expectedAnswer,
sortAnswers = true)
}
}
@@ -100,8 +107,9 @@ class SemiJoinSuite extends SparkPlanTest with SharedSQLContext {
test(s"$testName using BroadcastNestedLoopJoin build left") {
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
- BroadcastNestedLoopJoin(left, right, BuildLeft, LeftSemi, Some(condition)),
- expectedAnswer.map(Row.fromTuple),
+ EnsureRequirements(left.sqlContext.sessionState.conf).apply(
+ BroadcastNestedLoopJoin(left, right, BuildLeft, joinType, Some(condition))),
+ expectedAnswer,
sortAnswers = true)
}
}
@@ -109,21 +117,43 @@ class SemiJoinSuite extends SparkPlanTest with SharedSQLContext {
test(s"$testName using BroadcastNestedLoopJoin build right") {
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
- BroadcastNestedLoopJoin(left, right, BuildRight, LeftSemi, Some(condition)),
- expectedAnswer.map(Row.fromTuple),
+ EnsureRequirements(left.sqlContext.sessionState.conf).apply(
+ BroadcastNestedLoopJoin(left, right, BuildRight, joinType, Some(condition))),
+ expectedAnswer,
sortAnswers = true)
}
}
}
- testLeftSemiJoin(
- "basic test",
+ testExistenceJoin(
+ "basic test for left semi join",
+ LeftSemi,
+ left,
+ right,
+ condition,
+ Seq(Row(2, 1.0), Row(2, 1.0)))
+
+ testExistenceJoin(
+ "basic test for left semi non equal join",
+ LeftSemi,
+ left,
+ right,
+ conditionNEQ,
+ Seq(Row(1, 2.0), Row(1, 2.0), Row(2, 1.0), Row(2, 1.0)))
+
+ testExistenceJoin(
+ "basic test for anti join",
+ LeftAnti,
left,
right,
condition,
- Seq(
- (2, 1.0),
- (2, 1.0)
- )
- )
+ Seq(Row(1, 2.0), Row(1, 2.0), Row(3, 3.0), Row(6, null), Row(null, 5.0), Row(null, null)))
+
+ testExistenceJoin(
+ "basic test for anti non equal join",
+ LeftAnti,
+ left,
+ right,
+ conditionNEQ,
+ Seq(Row(3, 3.0), Row(6, null), Row(null, 5.0), Row(null, null)))
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
index e19b4ff1e2..371a9ed617 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
@@ -19,33 +19,43 @@ package org.apache.spark.sql.execution.joins
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream}
-import org.apache.spark.SparkFunSuite
+import org.apache.spark.{SparkConf, SparkFunSuite}
+import org.apache.spark.memory.{StaticMemoryManager, TaskMemoryManager}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
+import org.apache.spark.unsafe.map.BytesToBytesMap
import org.apache.spark.util.collection.CompactBuffer
class HashedRelationSuite extends SparkFunSuite with SharedSQLContext {
+ val mm = new TaskMemoryManager(
+ new StaticMemoryManager(
+ new SparkConf().set("spark.memory.offHeap.enabled", "false"),
+ Long.MaxValue,
+ Long.MaxValue,
+ 1),
+ 0)
+
test("UnsafeHashedRelation") {
val schema = StructType(StructField("a", IntegerType, true) :: Nil)
val data = Array(InternalRow(0), InternalRow(1), InternalRow(2), InternalRow(2))
val toUnsafe = UnsafeProjection.create(schema)
- val unsafeData = data.map(toUnsafe(_).copy()).toArray
+ val unsafeData = data.map(toUnsafe(_).copy())
+
val buildKey = Seq(BoundReference(0, IntegerType, false))
- val keyGenerator = UnsafeProjection.create(buildKey)
- val hashed = UnsafeHashedRelation(unsafeData.iterator, keyGenerator, 1)
+ val hashed = UnsafeHashedRelation(unsafeData.iterator, buildKey, 1, mm)
assert(hashed.isInstanceOf[UnsafeHashedRelation])
- assert(hashed.get(unsafeData(0)) === CompactBuffer[InternalRow](unsafeData(0)))
- assert(hashed.get(unsafeData(1)) === CompactBuffer[InternalRow](unsafeData(1)))
+ assert(hashed.get(unsafeData(0)).toArray === Array(unsafeData(0)))
+ assert(hashed.get(unsafeData(1)).toArray === Array(unsafeData(1)))
assert(hashed.get(toUnsafe(InternalRow(10))) === null)
val data2 = CompactBuffer[InternalRow](unsafeData(2).copy())
data2 += unsafeData(2).copy()
- assert(hashed.get(unsafeData(2)) === data2)
+ assert(hashed.get(unsafeData(2)).toArray === data2.toArray)
val os = new ByteArrayOutputStream()
val out = new ObjectOutputStream(os)
@@ -54,10 +64,10 @@ class HashedRelationSuite extends SparkFunSuite with SharedSQLContext {
val in = new ObjectInputStream(new ByteArrayInputStream(os.toByteArray))
val hashed2 = new UnsafeHashedRelation()
hashed2.readExternal(in)
- assert(hashed2.get(unsafeData(0)) === CompactBuffer[InternalRow](unsafeData(0)))
- assert(hashed2.get(unsafeData(1)) === CompactBuffer[InternalRow](unsafeData(1)))
+ assert(hashed2.get(unsafeData(0)).toArray === Array(unsafeData(0)))
+ assert(hashed2.get(unsafeData(1)).toArray === Array(unsafeData(1)))
assert(hashed2.get(toUnsafe(InternalRow(10))) === null)
- assert(hashed2.get(unsafeData(2)) === data2)
+ assert(hashed2.get(unsafeData(2)).toArray === data2)
val os2 = new ByteArrayOutputStream()
val out2 = new ObjectOutputStream(os2)
@@ -69,10 +79,17 @@ class HashedRelationSuite extends SparkFunSuite with SharedSQLContext {
}
test("test serialization empty hash map") {
+ val taskMemoryManager = new TaskMemoryManager(
+ new StaticMemoryManager(
+ new SparkConf().set("spark.memory.offHeap.enabled", "false"),
+ Long.MaxValue,
+ Long.MaxValue,
+ 1),
+ 0)
+ val binaryMap = new BytesToBytesMap(taskMemoryManager, 1, 1)
val os = new ByteArrayOutputStream()
val out = new ObjectOutputStream(os)
- val hashed = new UnsafeHashedRelation(
- new java.util.HashMap[UnsafeRow, CompactBuffer[UnsafeRow]])
+ val hashed = new UnsafeHashedRelation(1, binaryMap)
hashed.writeExternal(out)
out.flush()
val in = new ObjectInputStream(new ByteArrayInputStream(os.toByteArray))
@@ -91,31 +108,45 @@ class HashedRelationSuite extends SparkFunSuite with SharedSQLContext {
assert(java.util.Arrays.equals(os2.toByteArray, os.toByteArray))
}
- test("LongArrayRelation") {
+ test("LongToUnsafeRowMap") {
val unsafeProj = UnsafeProjection.create(
Seq(BoundReference(0, IntegerType, false), BoundReference(1, IntegerType, true)))
val rows = (0 until 100).map(i => unsafeProj(InternalRow(i, i + 1)).copy())
- val keyProj = UnsafeProjection.create(Seq(BoundReference(0, IntegerType, false)))
- val longRelation = LongHashedRelation(rows.iterator, keyProj, 100)
- assert(longRelation.isInstanceOf[LongArrayRelation])
- val longArrayRelation = longRelation.asInstanceOf[LongArrayRelation]
+ val key = Seq(BoundReference(0, IntegerType, false))
+ val longRelation = LongHashedRelation(rows.iterator, key, 10, mm)
+ assert(longRelation.keyIsUnique)
(0 until 100).foreach { i =>
- val row = longArrayRelation.getValue(i)
+ val row = longRelation.getValue(i)
assert(row.getInt(0) === i)
assert(row.getInt(1) === i + 1)
}
+ val longRelation2 = LongHashedRelation(rows.iterator ++ rows.iterator, key, 100, mm)
+ assert(!longRelation2.keyIsUnique)
+ (0 until 100).foreach { i =>
+ val rows = longRelation2.get(i).toArray
+ assert(rows.length === 2)
+ assert(rows(0).getInt(0) === i)
+ assert(rows(0).getInt(1) === i + 1)
+ assert(rows(1).getInt(0) === i)
+ assert(rows(1).getInt(1) === i + 1)
+ }
+
val os = new ByteArrayOutputStream()
val out = new ObjectOutputStream(os)
- longArrayRelation.writeExternal(out)
+ longRelation2.writeExternal(out)
out.flush()
val in = new ObjectInputStream(new ByteArrayInputStream(os.toByteArray))
- val relation = new LongArrayRelation()
+ val relation = new LongHashedRelation()
relation.readExternal(in)
+ assert(!relation.keyIsUnique)
(0 until 100).foreach { i =>
- val row = longArrayRelation.getValue(i)
- assert(row.getInt(0) === i)
- assert(row.getInt(1) === i + 1)
+ val rows = relation.get(i).toArray
+ assert(rows.length === 2)
+ assert(rows(0).getInt(0) === i)
+ assert(rows(0).getInt(1) === i + 1)
+ assert(rows(1).getInt(0) === i)
+ assert(rows(1).getInt(1) === i + 1)
}
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLogSuite.scala
index d5db9db36b..1328142704 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLogSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLogSuite.scala
@@ -21,6 +21,7 @@ import java.io.{File, FileNotFoundException, IOException}
import java.net.URI
import java.util.ConcurrentModificationException
+import scala.language.implicitConversions
import scala.util.Random
import org.apache.hadoop.conf.Configuration
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ProcessingTimeExecutorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ProcessingTimeExecutorSuite.scala
new file mode 100644
index 0000000000..dd5f92248b
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ProcessingTimeExecutorSuite.scala
@@ -0,0 +1,78 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import java.util.concurrent.{CountDownLatch, TimeUnit}
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.ProcessingTime
+import org.apache.spark.util.ManualClock
+
+class ProcessingTimeExecutorSuite extends SparkFunSuite {
+
+ test("nextBatchTime") {
+ val processingTimeExecutor = ProcessingTimeExecutor(ProcessingTime(100))
+ assert(processingTimeExecutor.nextBatchTime(1) === 100)
+ assert(processingTimeExecutor.nextBatchTime(99) === 100)
+ assert(processingTimeExecutor.nextBatchTime(100) === 100)
+ assert(processingTimeExecutor.nextBatchTime(101) === 200)
+ assert(processingTimeExecutor.nextBatchTime(150) === 200)
+ }
+
+ private def testBatchTermination(intervalMs: Long): Unit = {
+ var batchCounts = 0
+ val processingTimeExecutor = ProcessingTimeExecutor(ProcessingTime(intervalMs))
+ processingTimeExecutor.execute(() => {
+ batchCounts += 1
+ // If the batch termination works well, batchCounts should be 3 after `execute`
+ batchCounts < 3
+ })
+ assert(batchCounts === 3)
+ }
+
+ test("batch termination") {
+ testBatchTermination(0)
+ testBatchTermination(10)
+ }
+
+ test("notifyBatchFallingBehind") {
+ val clock = new ManualClock()
+ @volatile var batchFallingBehindCalled = false
+ val latch = new CountDownLatch(1)
+ val t = new Thread() {
+ override def run(): Unit = {
+ val processingTimeExecutor = new ProcessingTimeExecutor(ProcessingTime(100), clock) {
+ override def notifyBatchFallingBehind(realElapsedTimeMs: Long): Unit = {
+ batchFallingBehindCalled = true
+ }
+ }
+ processingTimeExecutor.execute(() => {
+ latch.countDown()
+ clock.waitTillTime(200)
+ false
+ })
+ }
+ }
+ t.start()
+ // Wait until the batch is running so that we don't call `advance` too early
+ assert(latch.await(10, TimeUnit.SECONDS), "the batch has not yet started in 10 seconds")
+ clock.advance(200)
+ t.join()
+ assert(batchFallingBehindCalled === true)
+ }
+}
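
The nextBatchTime assertions above pin down the intended rounding: the next batch time is the smallest multiple of the trigger interval that is greater than or equal to the current time, and an exact multiple maps to itself. A self-contained sketch of that arithmetic, assuming this ceiling behaviour is all the test relies on (nextMultiple is a hypothetical helper, not Spark's implementation):

    object NextBatchTimeSketch extends App {
      // Smallest multiple of intervalMs that is >= now; an exact multiple maps to itself.
      def nextMultiple(now: Long, intervalMs: Long): Long =
        ((now + intervalMs - 1) / intervalMs) * intervalMs

      assert(nextMultiple(1, 100) == 100)
      assert(nextMultiple(99, 100) == 100)
      assert(nextMultiple(100, 100) == 100)
      assert(nextMultiple(101, 100) == 200)
      assert(nextMultiple(150, 100) == 200)
    }
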
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala
index df50cbde56..6be94eb24f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala
@@ -33,7 +33,7 @@ import org.apache.spark.scheduler.ExecutorCacheTaskLocation
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.catalyst.util.quietly
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
-import org.apache.spark.util.Utils
+import org.apache.spark.util.{CompletionIterator, Utils}
class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAndAfterAll {
@@ -54,62 +54,93 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn
}
test("versioning and immutability") {
- quietly {
- withSpark(new SparkContext(sparkConf)) { sc =>
- implicit val sqlContet = new SQLContext(sc)
- val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString
- val increment = (store: StateStore, iter: Iterator[String]) => {
- iter.foreach { s =>
- store.update(
- stringToRow(s), oldRow => {
- val oldValue = oldRow.map(rowToInt).getOrElse(0)
- intToRow(oldValue + 1)
- })
- }
- store.commit()
- store.iterator().map(rowsToStringInt)
- }
- val opId = 0
- val rdd1 = makeRDD(sc, Seq("a", "b", "a")).mapPartitionWithStateStore(
- increment, path, opId, storeVersion = 0, keySchema, valueSchema)
- assert(rdd1.collect().toSet === Set("a" -> 2, "b" -> 1))
+ withSpark(new SparkContext(sparkConf)) { sc =>
+ val sqlContext = new SQLContext(sc)
+ val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString
+ val opId = 0
+ val rdd1 =
+ makeRDD(sc, Seq("a", "b", "a")).mapPartitionsWithStateStore(
+ sqlContext, path, opId, storeVersion = 0, keySchema, valueSchema)(
+ increment)
+ assert(rdd1.collect().toSet === Set("a" -> 2, "b" -> 1))
+
+ // Generate next version of stores
+ val rdd2 = makeRDD(sc, Seq("a", "c")).mapPartitionsWithStateStore(
+ sqlContext, path, opId, storeVersion = 1, keySchema, valueSchema)(increment)
+ assert(rdd2.collect().toSet === Set("a" -> 3, "b" -> 1, "c" -> 1))
+
+ // Make sure the previous RDD still has the same data.
+ assert(rdd1.collect().toSet === Set("a" -> 2, "b" -> 1))
+ }
+ }
- // Generate next version of stores
- val rdd2 = makeRDD(sc, Seq("a", "c")).mapPartitionWithStateStore(
- increment, path, opId, storeVersion = 1, keySchema, valueSchema)
- assert(rdd2.collect().toSet === Set("a" -> 3, "b" -> 1, "c" -> 1))
+ test("recovering from files") {
+ val opId = 0
+ val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString
+
+ def makeStoreRDD(
+ sc: SparkContext,
+ seq: Seq[String],
+ storeVersion: Int): RDD[(String, Int)] = {
+ implicit val sqlContext = new SQLContext(sc)
+ makeRDD(sc, Seq("a")).mapPartitionsWithStateStore(
+ sqlContext, path, opId, storeVersion, keySchema, valueSchema)(increment)
+ }
- // Make sure the previous RDD still has the same data.
- assert(rdd1.collect().toSet === Set("a" -> 2, "b" -> 1))
+ // Generate RDDs and state store data
+ withSpark(new SparkContext(sparkConf)) { sc =>
+ for (i <- 1 to 20) {
+ require(makeStoreRDD(sc, Seq("a"), i - 1).collect().toSet === Set("a" -> i))
}
}
+
+ // With a new context, try using the earlier state store data
+ withSpark(new SparkContext(sparkConf)) { sc =>
+ assert(makeStoreRDD(sc, Seq("a"), 20).collect().toSet === Set("a" -> 21))
+ }
}
- test("recovering from files") {
- quietly {
- val opId = 0
+ test("usage with iterators - only gets and only puts") {
+ withSpark(new SparkContext(sparkConf)) { sc =>
+ implicit val sqlContext = new SQLContext(sc)
val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString
+ val opId = 0
- def makeStoreRDD(
- sc: SparkContext,
- seq: Seq[String],
- storeVersion: Int): RDD[(String, Int)] = {
- implicit val sqlContext = new SQLContext(sc)
- makeRDD(sc, Seq("a")).mapPartitionWithStateStore(
- increment, path, opId, storeVersion, keySchema, valueSchema)
+ // Returns an iterator of the incremented values that were put into the store
+ def iteratorOfPuts(store: StateStore, iter: Iterator[String]): Iterator[(String, Int)] = {
+ val resIterator = iter.map { s =>
+ val key = stringToRow(s)
+ val oldValue = store.get(key).map(rowToInt).getOrElse(0)
+ val newValue = oldValue + 1
+ store.put(key, intToRow(newValue))
+ (s, newValue)
+ }
+ CompletionIterator[(String, Int), Iterator[(String, Int)]](resIterator, {
+ store.commit()
+ })
}
- // Generate RDDs and state store data
- withSpark(new SparkContext(sparkConf)) { sc =>
- for (i <- 1 to 20) {
- require(makeStoreRDD(sc, Seq("a"), i - 1).collect().toSet === Set("a" -> i))
+ def iteratorOfGets(
+ store: StateStore,
+ iter: Iterator[String]): Iterator[(String, Option[Int])] = {
+ iter.map { s =>
+ val key = stringToRow(s)
+ val value = store.get(key).map(rowToInt)
+ (s, value)
}
}
- // With a new context, try using the earlier state store data
- withSpark(new SparkContext(sparkConf)) { sc =>
- assert(makeStoreRDD(sc, Seq("a"), 20).collect().toSet === Set("a" -> 21))
- }
+ val rddOfGets1 = makeRDD(sc, Seq("a", "b", "c")).mapPartitionsWithStateStore(
+ sqlContext, path, opId, storeVersion = 0, keySchema, valueSchema)(iteratorOfGets)
+ assert(rddOfGets1.collect().toSet === Set("a" -> None, "b" -> None, "c" -> None))
+
+ val rddOfPuts = makeRDD(sc, Seq("a", "b", "a")).mapPartitionsWithStateStore(
+ sqlContext, path, opId, storeVersion = 0, keySchema, valueSchema)(iteratorOfPuts)
+ assert(rddOfPuts.collect().toSet === Set("a" -> 1, "a" -> 2, "b" -> 1))
+
+ val rddOfGets2 = makeRDD(sc, Seq("a", "b", "c")).mapPartitionsWithStateStore(
+ sqlContext, path, opId, storeVersion = 1, keySchema, valueSchema)(iteratorOfGets)
+ assert(rddOfGets2.collect().toSet === Set("a" -> Some(2), "b" -> Some(1), "c" -> None))
}
}
@@ -124,14 +155,12 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn
coordinatorRef.reportActiveInstance(StateStoreId(path, opId, 0), "host1", "exec1")
coordinatorRef.reportActiveInstance(StateStoreId(path, opId, 1), "host2", "exec2")
- eventually(timeout(10 seconds)) {
- assert(
- coordinatorRef.getLocation(StateStoreId(path, opId, 0)) ===
- Some(ExecutorCacheTaskLocation("host1", "exec1").toString))
- }
+ assert(
+ coordinatorRef.getLocation(StateStoreId(path, opId, 0)) ===
+ Some(ExecutorCacheTaskLocation("host1", "exec1").toString))
- val rdd = makeRDD(sc, Seq("a", "b", "a")).mapPartitionWithStateStore(
- increment, path, opId, storeVersion = 0, keySchema, valueSchema)
+ val rdd = makeRDD(sc, Seq("a", "b", "a")).mapPartitionsWithStateStore(
+ sqlContext, path, opId, storeVersion = 0, keySchema, valueSchema)(increment)
require(rdd.partitions.length === 2)
assert(
@@ -150,27 +179,16 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn
test("distributed test") {
quietly {
withSpark(new SparkContext(sparkConf.setMaster("local-cluster[2, 1, 1024]"))) { sc =>
- implicit val sqlContet = new SQLContext(sc)
+ implicit val sqlContext = new SQLContext(sc)
val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString
- val increment = (store: StateStore, iter: Iterator[String]) => {
- iter.foreach { s =>
- store.update(
- stringToRow(s), oldRow => {
- val oldValue = oldRow.map(rowToInt).getOrElse(0)
- intToRow(oldValue + 1)
- })
- }
- store.commit()
- store.iterator().map(rowsToStringInt)
- }
val opId = 0
- val rdd1 = makeRDD(sc, Seq("a", "b", "a")).mapPartitionWithStateStore(
- increment, path, opId, storeVersion = 0, keySchema, valueSchema)
+ val rdd1 = makeRDD(sc, Seq("a", "b", "a")).mapPartitionsWithStateStore(
+ sqlContext, path, opId, storeVersion = 0, keySchema, valueSchema)(increment)
assert(rdd1.collect().toSet === Set("a" -> 2, "b" -> 1))
// Generate next version of stores
- val rdd2 = makeRDD(sc, Seq("a", "c")).mapPartitionWithStateStore(
- increment, path, opId, storeVersion = 1, keySchema, valueSchema)
+ val rdd2 = makeRDD(sc, Seq("a", "c")).mapPartitionsWithStateStore(
+ sqlContext, path, opId, storeVersion = 1, keySchema, valueSchema)(increment)
assert(rdd2.collect().toSet === Set("a" -> 3, "b" -> 1, "c" -> 1))
// Make sure the previous RDD still has the same data.
@@ -185,11 +203,9 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn
private val increment = (store: StateStore, iter: Iterator[String]) => {
iter.foreach { s =>
- store.update(
- stringToRow(s), oldRow => {
- val oldValue = oldRow.map(rowToInt).getOrElse(0)
- intToRow(oldValue + 1)
- })
+ val key = stringToRow(s)
+ val oldValue = store.get(key).map(rowToInt).getOrElse(0)
+ store.put(key, intToRow(oldValue + 1))
}
store.commit()
store.iterator().map(rowsToStringInt)
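
The iteratorOfPuts helper above defers store.commit() until its result has been fully consumed, by wrapping it in Spark's CompletionIterator. A minimal, self-contained sketch of that wrap-and-run-on-exhaustion idea (onCompletion below is a simplified, hypothetical stand-in for the real class):

    object CompletionIteratorSketch extends App {
      // Wraps an iterator so `completion` runs exactly once, right after the last element.
      def onCompletion[A](underlying: Iterator[A])(completion: => Unit): Iterator[A] =
        new Iterator[A] {
          private var done = false
          def hasNext: Boolean = {
            val more = underlying.hasNext
            if (!more && !done) { done = true; completion }
            more
          }
          def next(): A = underlying.next()
        }

      var committed = false
      val values = onCompletion(Iterator(1, 2, 3)) { committed = true }

      assert(!committed)                      // nothing runs before the iterator is drained
      assert(values.toList == List(1, 2, 3))
      assert(committed)                       // completion ran once the iterator was exhausted
    }
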
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala
index 22b2f4f75d..dd23925716 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala
@@ -51,7 +51,7 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth
StateStore.stop()
}
- test("update, remove, commit, and all data iterator") {
+ test("get, put, remove, commit, and all data iterator") {
val provider = newStoreProvider()
// Verify state before starting a new set of updates
@@ -67,7 +67,7 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth
}
// Verify state after updating
- update(store, "a", 1)
+ put(store, "a", 1)
intercept[IllegalStateException] {
store.iterator()
}
@@ -77,8 +77,8 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth
assert(provider.latestIterator().isEmpty)
// Make updates, commit and then verify state
- update(store, "b", 2)
- update(store, "aa", 3)
+ put(store, "b", 2)
+ put(store, "aa", 3)
remove(store, _.startsWith("a"))
assert(store.commit() === 1)
@@ -101,7 +101,7 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth
val reloadedProvider = new HDFSBackedStateStoreProvider(
store.id, keySchema, valueSchema, StateStoreConf.empty, new Configuration)
val reloadedStore = reloadedProvider.getStore(1)
- update(reloadedStore, "c", 4)
+ put(reloadedStore, "c", 4)
assert(reloadedStore.commit() === 2)
assert(rowsToSet(reloadedStore.iterator()) === Set("b" -> 2, "c" -> 4))
assert(getDataFromFiles(provider) === Set("b" -> 2, "c" -> 4))
@@ -112,6 +112,7 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth
test("updates iterator with all combos of updates and removes") {
val provider = newStoreProvider()
var currentVersion: Int = 0
+
def withStore(body: StateStore => Unit): Unit = {
val store = provider.getStore(currentVersion)
body(store)
@@ -120,9 +121,9 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth
// New data should be seen in updates as value added, even if they had multiple updates
withStore { store =>
- update(store, "a", 1)
- update(store, "aa", 1)
- update(store, "aa", 2)
+ put(store, "a", 1)
+ put(store, "aa", 1)
+ put(store, "aa", 2)
store.commit()
assert(updatesToSet(store.updates()) === Set(Added("a", 1), Added("aa", 2)))
assert(rowsToSet(store.iterator()) === Set("a" -> 1, "aa" -> 2))
@@ -131,8 +132,8 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth
// Multiple updates to same key should be collapsed in the updates as a single value update
// Keys that have not been updated should not appear in the updates
withStore { store =>
- update(store, "a", 4)
- update(store, "a", 6)
+ put(store, "a", 4)
+ put(store, "a", 6)
store.commit()
assert(updatesToSet(store.updates()) === Set(Updated("a", 6)))
assert(rowsToSet(store.iterator()) === Set("a" -> 6, "aa" -> 2))
@@ -140,9 +141,9 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth
// Keys added, updated and finally removed before commit should not appear in updates
withStore { store =>
- update(store, "b", 4) // Added, finally removed
- update(store, "bb", 5) // Added, updated, finally removed
- update(store, "bb", 6)
+ put(store, "b", 4) // Added, finally removed
+ put(store, "bb", 5) // Added, updated, finally removed
+ put(store, "bb", 6)
remove(store, _.startsWith("b"))
store.commit()
assert(updatesToSet(store.updates()) === Set.empty)
@@ -153,7 +154,7 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth
// Removed, but re-added data should be seen in updates as a value update
withStore { store =>
remove(store, _.startsWith("a"))
- update(store, "a", 10)
+ put(store, "a", 10)
store.commit()
assert(updatesToSet(store.updates()) === Set(Updated("a", 10), Removed("aa")))
assert(rowsToSet(store.iterator()) === Set("a" -> 10))
@@ -163,14 +164,14 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth
test("cancel") {
val provider = newStoreProvider()
val store = provider.getStore(0)
- update(store, "a", 1)
+ put(store, "a", 1)
store.commit()
assert(rowsToSet(store.iterator()) === Set("a" -> 1))
// cancelUpdates should not change the data in the files
val store1 = provider.getStore(1)
- update(store1, "b", 1)
- store1.cancel()
+ put(store1, "b", 1)
+ store1.abort()
assert(getDataFromFiles(provider) === Set("a" -> 1))
}
@@ -183,7 +184,7 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth
// Prepare some data in the store
val store = provider.getStore(0)
- update(store, "a", 1)
+ put(store, "a", 1)
assert(store.commit() === 1)
assert(rowsToSet(store.iterator()) === Set("a" -> 1))
@@ -193,14 +194,14 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth
// Update store version with some data
val store1 = provider.getStore(1)
- update(store1, "b", 1)
+ put(store1, "b", 1)
assert(store1.commit() === 2)
assert(rowsToSet(store1.iterator()) === Set("a" -> 1, "b" -> 1))
assert(getDataFromFiles(provider) === Set("a" -> 1, "b" -> 1))
// Overwrite the version with other data
val store2 = provider.getStore(1)
- update(store2, "c", 1)
+ put(store2, "c", 1)
assert(store2.commit() === 2)
assert(rowsToSet(store2.iterator()) === Set("a" -> 1, "c" -> 1))
assert(getDataFromFiles(provider) === Set("a" -> 1, "c" -> 1))
@@ -213,7 +214,7 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth
def updateVersionTo(targetVersion: Int): Unit = {
for (i <- currentVersion + 1 to targetVersion) {
val store = provider.getStore(currentVersion)
- update(store, "a", i)
+ put(store, "a", i)
store.commit()
currentVersion += 1
}
@@ -264,7 +265,7 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth
for (i <- 1 to 20) {
val store = provider.getStore(i - 1)
- update(store, "a", i)
+ put(store, "a", i)
store.commit()
provider.doMaintenance() // do cleanup
}
@@ -284,7 +285,7 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth
val provider = newStoreProvider(minDeltasForSnapshot = 5)
for (i <- 1 to 6) {
val store = provider.getStore(i - 1)
- update(store, "a", i)
+ put(store, "a", i)
store.commit()
provider.doMaintenance() // do cleanup
}
@@ -333,7 +334,7 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth
// Increase version of the store
val store0 = StateStore.get(storeId, keySchema, valueSchema, 0, storeConf, hadoopConf)
assert(store0.version === 0)
- update(store0, "a", 1)
+ put(store0, "a", 1)
store0.commit()
assert(StateStore.get(storeId, keySchema, valueSchema, 1, storeConf, hadoopConf).version == 1)
@@ -345,13 +346,13 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth
val store1 = StateStore.get(storeId, keySchema, valueSchema, 1, storeConf, hadoopConf)
assert(StateStore.isLoaded(storeId))
- update(store1, "a", 2)
+ put(store1, "a", 2)
assert(store1.commit() === 2)
assert(rowsToSet(store1.iterator()) === Set("a" -> 2))
}
}
- test("maintenance") {
+ ignore("maintenance") {
val conf = new SparkConf()
.setMaster("local")
.setAppName("test")
@@ -371,7 +372,7 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth
for (i <- 1 to 20) {
val store = StateStore.get(
storeId, keySchema, valueSchema, i - 1, storeConf, hadoopConf)
- update(store, "a", i)
+ put(store, "a", i)
store.commit()
}
eventually(timeout(10 seconds)) {
@@ -507,8 +508,12 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth
store.remove(row => condition(rowToString(row)))
}
- private def update(store: StateStore, key: String, value: Int): Unit = {
- store.update(stringToRow(key), _ => intToRow(value))
+ private def put(store: StateStore, key: String, value: Int): Unit = {
+ store.put(stringToRow(key), intToRow(value))
+ }
+
+ private def get(store: StateStore, key: String): Option[Int] = {
+ store.get(stringToRow(key)).map(rowToInt)
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala
index 4262097e8f..31b63f2ce1 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala
@@ -612,23 +612,20 @@ class ColumnarBatchSuite extends SparkFunSuite {
val a2 = r2.getList(v._2).toArray
assert(a1.length == a2.length, "Seed = " + seed)
childType match {
- case DoubleType => {
+ case DoubleType =>
var i = 0
while (i < a1.length) {
assert(doubleEquals(a1(i).asInstanceOf[Double], a2(i).asInstanceOf[Double]),
"Seed = " + seed)
i += 1
}
- }
- case FloatType => {
+ case FloatType =>
var i = 0
while (i < a1.length) {
assert(doubleEquals(a1(i).asInstanceOf[Float], a2(i).asInstanceOf[Float]),
"Seed = " + seed)
i += 1
}
- }
-
case t: DecimalType =>
var i = 0
while (i < a1.length) {
@@ -640,7 +637,6 @@ class ColumnarBatchSuite extends SparkFunSuite {
}
i += 1
}
-
case _ => assert(a1 === a2, "Seed = " + seed)
}
case StructType(childFields) =>
@@ -756,4 +752,25 @@ class ColumnarBatchSuite extends SparkFunSuite {
}}
}
}
+
+ test("mutable ColumnarBatch rows") {
+ val NUM_ITERS = 10
+ val types = Array(
+ BooleanType, FloatType, DoubleType,
+ IntegerType, LongType, ShortType, DecimalType.IntDecimal, new DecimalType(30, 10))
+ for (i <- 0 to NUM_ITERS) {
+ val random = new Random(System.nanoTime())
+ val schema = RandomDataGenerator.randomSchema(random, numFields = 20, types)
+ val oldRow = RandomDataGenerator.randomRow(random, schema)
+ val newRow = RandomDataGenerator.randomRow(random, schema)
+
+ (MemoryMode.ON_HEAP :: MemoryMode.OFF_HEAP :: Nil).foreach { memMode =>
+ val batch = ColumnVectorUtils.toBatch(schema, memMode, (oldRow :: Nil).iterator.asJava)
+ val columnarBatchRow = batch.getRow(0)
+ newRow.toSeq.zipWithIndex.foreach(i => columnarBatchRow.update(i._2, i._1))
+ compareStruct(schema, columnarBatchRow, newRow, 0)
+ batch.close()
+ }
+ }
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfEntrySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfEntrySuite.scala
index 2b89fa9f23..cc69199139 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfEntrySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfEntrySuite.scala
@@ -26,7 +26,7 @@ class SQLConfEntrySuite extends SparkFunSuite {
test("intConf") {
val key = "spark.sql.SQLConfEntrySuite.int"
- val confEntry = SQLConfEntry.intConf(key)
+ val confEntry = SQLConfigBuilder(key).intConf.createWithDefault(1)
assert(conf.getConf(confEntry, 5) === 5)
conf.setConf(confEntry, 10)
@@ -45,7 +45,7 @@ class SQLConfEntrySuite extends SparkFunSuite {
test("longConf") {
val key = "spark.sql.SQLConfEntrySuite.long"
- val confEntry = SQLConfEntry.longConf(key)
+ val confEntry = SQLConfigBuilder(key).longConf.createWithDefault(1L)
assert(conf.getConf(confEntry, 5L) === 5L)
conf.setConf(confEntry, 10L)
@@ -64,7 +64,7 @@ class SQLConfEntrySuite extends SparkFunSuite {
test("booleanConf") {
val key = "spark.sql.SQLConfEntrySuite.boolean"
- val confEntry = SQLConfEntry.booleanConf(key)
+ val confEntry = SQLConfigBuilder(key).booleanConf.createWithDefault(true)
assert(conf.getConf(confEntry, false) === false)
conf.setConf(confEntry, true)
@@ -83,7 +83,7 @@ class SQLConfEntrySuite extends SparkFunSuite {
test("doubleConf") {
val key = "spark.sql.SQLConfEntrySuite.double"
- val confEntry = SQLConfEntry.doubleConf(key)
+ val confEntry = SQLConfigBuilder(key).doubleConf.createWithDefault(1d)
assert(conf.getConf(confEntry, 5.0) === 5.0)
conf.setConf(confEntry, 10.0)
@@ -102,7 +102,7 @@ class SQLConfEntrySuite extends SparkFunSuite {
test("stringConf") {
val key = "spark.sql.SQLConfEntrySuite.string"
- val confEntry = SQLConfEntry.stringConf(key)
+ val confEntry = SQLConfigBuilder(key).stringConf.createWithDefault(null)
assert(conf.getConf(confEntry, "abc") === "abc")
conf.setConf(confEntry, "abcd")
@@ -116,7 +116,10 @@ class SQLConfEntrySuite extends SparkFunSuite {
test("enumConf") {
val key = "spark.sql.SQLConfEntrySuite.enum"
- val confEntry = SQLConfEntry.enumConf(key, v => v, Set("a", "b", "c"), defaultValue = Some("a"))
+ val confEntry = SQLConfigBuilder(key)
+ .stringConf
+ .checkValues(Set("a", "b", "c"))
+ .createWithDefault("a")
assert(conf.getConf(confEntry) === "a")
conf.setConf(confEntry, "b")
@@ -135,8 +138,10 @@ class SQLConfEntrySuite extends SparkFunSuite {
test("stringSeqConf") {
val key = "spark.sql.SQLConfEntrySuite.stringSeq"
- val confEntry = SQLConfEntry.stringSeqConf("spark.sql.SQLConfEntrySuite.stringSeq",
- defaultValue = Some(Nil))
+ val confEntry = SQLConfigBuilder(key)
+ .stringConf
+ .toSequence
+ .createWithDefault(Nil)
assert(conf.getConf(confEntry, Seq("a", "b", "c")) === Seq("a", "b", "c"))
conf.setConf(confEntry, Seq("a", "b", "c", "d"))
@@ -147,4 +152,12 @@ class SQLConfEntrySuite extends SparkFunSuite {
assert(conf.getConfString(key) === "a,b,c,d,e")
assert(conf.getConf(confEntry, Seq("a", "b", "c")) === Seq("a", "b", "c", "d", "e"))
}
+
+ test("duplicate entry") {
+ val key = "spark.sql.SQLConfEntrySuite.duplicate"
+ SQLConfigBuilder(key).stringConf.createOptional
+ intercept[IllegalArgumentException] {
+ SQLConfigBuilder(key).stringConf.createOptional
+ }
+ }
}
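
The suite above now builds entries through a fluent builder (stringConf / intConf, checkValues, toSequence, createWithDefault) instead of per-type factory methods. A self-contained sketch of that builder shape, covering only what the tests exercise (TinyConfigBuilder and friends are hypothetical stand-ins, not Spark's SQLConfigBuilder):

    // Hypothetical stand-ins illustrating the fluent builder shape used above.
    final case class TinyConfigEntry[T](key: String, default: T, valueConverter: String => T)

    final class TinyTypedBuilder[T](key: String, converter: String => T) {
      def checkValues(allowed: Set[T]): TinyTypedBuilder[T] =
        new TinyTypedBuilder[T](key, s => {
          val v = converter(s)
          require(allowed.contains(v), s"$key must be one of $allowed, got $v")
          v
        })

      def toSequence: TinyTypedBuilder[Seq[T]] =
        new TinyTypedBuilder[Seq[T]](key, s => s.split(",").toSeq.map(converter))

      def createWithDefault(default: T): TinyConfigEntry[T] =
        TinyConfigEntry(key, default, converter)
    }

    final class TinyConfigBuilder(key: String) {
      def intConf: TinyTypedBuilder[Int] = new TinyTypedBuilder(key, _.toInt)
      def stringConf: TinyTypedBuilder[String] = new TinyTypedBuilder(key, (s: String) => s)
    }

    object TinyConfigExample extends App {
      val enumEntry = new TinyConfigBuilder("example.enum").stringConf
        .checkValues(Set("a", "b", "c"))
        .createWithDefault("a")
      assert(enumEntry.valueConverter("b") == "b")

      val seqEntry = new TinyConfigBuilder("example.seq").stringConf.toSequence.createWithDefault(Nil)
      assert(seqEntry.valueConverter("a,b,c") == Seq("a", "b", "c"))
    }
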
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala
index e944d328a3..e687e6a5ce 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala
@@ -119,15 +119,10 @@ class SQLConfSuite extends QueryTest with SharedSQLContext {
}
intercept[IllegalArgumentException] {
- // This value less than Int.MinValue
+ // This value is less than Long.MinValue
sqlContext.setConf(SQLConf.SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE.key, "-90000000000g")
}
- // Test invalid input
- intercept[IllegalArgumentException] {
- // This value exceeds Long.MaxValue
- sqlContext.setConf(SQLConf.SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE.key, "-1g")
- }
sqlContext.conf.clear()
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/ContinuousQueryManagerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/ContinuousQueryManagerSuite.scala
index 54ce98d195..3d69c8a187 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/ContinuousQueryManagerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/ContinuousQueryManagerSuite.scala
@@ -29,7 +29,7 @@ import org.scalatest.time.SpanSugar._
import org.apache.spark.SparkException
import org.apache.spark.sql.{ContinuousQuery, Dataset, StreamTest}
-import org.apache.spark.sql.execution.streaming.{MemorySink, MemoryStream, StreamExecution, StreamingRelation}
+import org.apache.spark.sql.execution.streaming._
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.util.Utils
@@ -185,8 +185,8 @@ class ContinuousQueryManagerSuite extends StreamTest with SharedSQLContext with
val q2 = stopRandomQueryAsync(100 milliseconds, withError = true)
testAwaitAnyTermination(
ExpectException[SparkException],
- awaitTimeout = 1 seconds,
- testBehaviorFor = 2 seconds)
+ awaitTimeout = 4 seconds,
+ testBehaviorFor = 6 seconds)
require(!q2.isActive) // should be inactive by the time the prev awaitAnyTerm returned
// All subsequent calls to awaitAnyTermination should throw the exception
@@ -236,7 +236,8 @@ class ContinuousQueryManagerSuite extends StreamTest with SharedSQLContext with
@volatile var query: StreamExecution = null
try {
val df = ds.toDF
- val metadataRoot = Utils.createTempDir("streaming.metadata").getCanonicalPath
+ val metadataRoot =
+ Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
query = sqlContext
.streams
.startQuery(
@@ -293,8 +294,8 @@ class ContinuousQueryManagerSuite extends StreamTest with SharedSQLContext with
if (withError) {
logDebug(s"Terminating query ${queryToStop.name} with error")
queryToStop.asInstanceOf[StreamExecution].logicalPlan.collect {
- case StreamingRelation(memoryStream, _) =>
- memoryStream.asInstanceOf[MemoryStream[Int]].addData(0)
+ case StreamingExecutionRelation(source, _) =>
+ source.asInstanceOf[MemoryStream[Int]].addData(0)
}
} else {
logDebug(s"Stopping query ${queryToStop.name}")
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/DataFrameReaderWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/DataFrameReaderWriterSuite.scala
index c1bab9b577..00efe21d39 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/DataFrameReaderWriterSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/DataFrameReaderWriterSuite.scala
@@ -17,6 +17,11 @@
package org.apache.spark.sql.streaming.test
+import java.util.concurrent.TimeUnit
+
+import scala.concurrent.duration._
+
+import org.mockito.Mockito._
import org.scalatest.BeforeAndAfter
import org.apache.spark.sql._
@@ -27,22 +32,50 @@ import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
import org.apache.spark.util.Utils
object LastOptions {
+
+ var mockStreamSourceProvider = mock(classOf[StreamSourceProvider])
+ var mockStreamSinkProvider = mock(classOf[StreamSinkProvider])
var parameters: Map[String, String] = null
var schema: Option[StructType] = null
var partitionColumns: Seq[String] = Nil
+
+ def clear(): Unit = {
+ parameters = null
+ schema = null
+ partitionColumns = null
+ reset(mockStreamSourceProvider)
+ reset(mockStreamSinkProvider)
+ }
}
/** Dummy provider: returns no-op source/sink and records options in [[LastOptions]]. */
class DefaultSource extends StreamSourceProvider with StreamSinkProvider {
+
+ private val fakeSchema = StructType(StructField("a", IntegerType) :: Nil)
+
+ override def sourceSchema(
+ sqlContext: SQLContext,
+ schema: Option[StructType],
+ providerName: String,
+ parameters: Map[String, String]): (String, StructType) = {
+ LastOptions.parameters = parameters
+ LastOptions.schema = schema
+ LastOptions.mockStreamSourceProvider.sourceSchema(sqlContext, schema, providerName, parameters)
+ ("dummySource", fakeSchema)
+ }
+
override def createSource(
sqlContext: SQLContext,
+ metadataPath: String,
schema: Option[StructType],
providerName: String,
parameters: Map[String, String]): Source = {
LastOptions.parameters = parameters
LastOptions.schema = schema
+ LastOptions.mockStreamSourceProvider.createSource(
+ sqlContext, metadataPath, schema, providerName, parameters)
new Source {
- override def schema: StructType = StructType(StructField("a", IntegerType) :: Nil)
+ override def schema: StructType = fakeSchema
override def getOffset: Option[Offset] = Some(new LongOffset(0))
@@ -60,6 +93,7 @@ class DefaultSource extends StreamSourceProvider with StreamSinkProvider {
partitionColumns: Seq[String]): Sink = {
LastOptions.parameters = parameters
LastOptions.partitionColumns = partitionColumns
+ LastOptions.mockStreamSinkProvider.createSink(sqlContext, parameters, partitionColumns)
new Sink {
override def addBatch(batchId: Long, data: DataFrame): Unit = {}
}
@@ -69,7 +103,8 @@ class DefaultSource extends StreamSourceProvider with StreamSinkProvider {
class DataFrameReaderWriterSuite extends StreamTest with SharedSQLContext with BeforeAndAfter {
import testImplicits._
- private def newMetadataDir = Utils.createTempDir("streaming.metadata").getCanonicalPath
+ private def newMetadataDir =
+ Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
after {
sqlContext.streams.active.foreach(_.stop())
@@ -112,7 +147,7 @@ class DataFrameReaderWriterSuite extends StreamTest with SharedSQLContext with B
assert(LastOptions.parameters("opt2") == "2")
assert(LastOptions.parameters("opt3") == "3")
- LastOptions.parameters = null
+ LastOptions.clear()
df.write
.format("org.apache.spark.sql.streaming.test")
@@ -176,7 +211,7 @@ class DataFrameReaderWriterSuite extends StreamTest with SharedSQLContext with B
assert(LastOptions.parameters("path") == "/test")
- LastOptions.parameters = null
+ LastOptions.clear()
df.write
.format("org.apache.spark.sql.streaming.test")
@@ -199,7 +234,7 @@ class DataFrameReaderWriterSuite extends StreamTest with SharedSQLContext with B
assert(LastOptions.parameters("boolOpt") == "false")
assert(LastOptions.parameters("doubleOpt") == "6.7")
- LastOptions.parameters = null
+ LastOptions.clear()
df.write
.format("org.apache.spark.sql.streaming.test")
.option("intOpt", 56)
@@ -274,4 +309,63 @@ class DataFrameReaderWriterSuite extends StreamTest with SharedSQLContext with B
assert(activeStreamNames.contains("name"))
sqlContext.streams.active.foreach(_.stop())
}
+
+ test("trigger") {
+ val df = sqlContext.read
+ .format("org.apache.spark.sql.streaming.test")
+ .stream("/test")
+
+ var q = df.write
+ .format("org.apache.spark.sql.streaming.test")
+ .option("checkpointLocation", newMetadataDir)
+ .trigger(ProcessingTime(10.seconds))
+ .startStream()
+ q.stop()
+
+ assert(q.asInstanceOf[StreamExecution].trigger == ProcessingTime(10000))
+
+ q = df.write
+ .format("org.apache.spark.sql.streaming.test")
+ .option("checkpointLocation", newMetadataDir)
+ .trigger(ProcessingTime.create(100, TimeUnit.SECONDS))
+ .startStream()
+ q.stop()
+
+ assert(q.asInstanceOf[StreamExecution].trigger == ProcessingTime(100000))
+ }
+
+ test("source metadataPath") {
+ LastOptions.clear()
+
+ val checkpointLocation = newMetadataDir
+
+ val df1 = sqlContext.read
+ .format("org.apache.spark.sql.streaming.test")
+ .stream()
+
+ val df2 = sqlContext.read
+ .format("org.apache.spark.sql.streaming.test")
+ .stream()
+
+ val q = df1.union(df2).write
+ .format("org.apache.spark.sql.streaming.test")
+ .option("checkpointLocation", checkpointLocation)
+ .trigger(ProcessingTime(10.seconds))
+ .startStream()
+ q.stop()
+
+ verify(LastOptions.mockStreamSourceProvider).createSource(
+ sqlContext,
+ checkpointLocation + "/sources/0",
+ None,
+ "org.apache.spark.sql.streaming.test",
+ Map.empty)
+
+ verify(LastOptions.mockStreamSourceProvider).createSource(
+ sqlContext,
+ checkpointLocation + "/sources/1",
+ None,
+ "org.apache.spark.sql.streaming.test",
+ Map.empty)
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala
index 7f31611383..8cf5dedabc 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala
@@ -29,8 +29,8 @@ class FileStreamSinkSuite extends StreamTest with SharedSQLContext {
val inputData = MemoryStream[Int]
val df = inputData.toDF()
- val outputDir = Utils.createTempDir("stream.output").getCanonicalPath
- val checkpointDir = Utils.createTempDir("stream.checkpoint").getCanonicalPath
+ val outputDir = Utils.createTempDir(namePrefix = "stream.output").getCanonicalPath
+ val checkpointDir = Utils.createTempDir(namePrefix = "stream.checkpoint").getCanonicalPath
val query =
df.write
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala
index 89de15acf5..73d1b1b1d5 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala
@@ -63,6 +63,7 @@ class FileStreamSourceTest extends StreamTest with SharedSQLContext {
format: String,
path: String,
schema: Option[StructType] = None): FileStreamSource = {
+ val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
val reader =
if (schema.isDefined) {
sqlContext.read.format(format).schema(schema.get)
@@ -71,8 +72,10 @@ class FileStreamSourceTest extends StreamTest with SharedSQLContext {
}
reader.stream(path)
.queryExecution.analyzed
- .collect { case StreamingRelation(s: FileStreamSource, _) => s }
- .head
+ .collect { case StreamingRelation(dataSource, _, _) =>
+ // There is only one source in our tests so just set sourceId to 0
+ dataSource.createSource(s"$checkpointLocation/sources/0").asInstanceOf[FileStreamSource]
+ }.head
}
val valueSchema = new StructType().add("value", StringType)
@@ -96,9 +99,9 @@ class FileStreamSourceSuite extends FileStreamSourceTest with SharedSQLContext {
reader.stream()
}
df.queryExecution.analyzed
- .collect { case StreamingRelation(s: FileStreamSource, _) => s }
- .head
- .schema
+ .collect { case StreamingRelation(dataSource, _, _) =>
+ dataSource.sourceSchema()
+ }.head._2
}
test("FileStreamSource schema: no path") {
@@ -202,8 +205,8 @@ class FileStreamSourceSuite extends FileStreamSourceTest with SharedSQLContext {
}
test("read from text files") {
- val src = Utils.createTempDir("streaming.src")
- val tmp = Utils.createTempDir("streaming.tmp")
+ val src = Utils.createTempDir(namePrefix = "streaming.src")
+ val tmp = Utils.createTempDir(namePrefix = "streaming.tmp")
val textSource = createFileStreamSource("text", src.getCanonicalPath)
val filtered = textSource.toDF().filter($"value" contains "keep")
@@ -224,8 +227,8 @@ class FileStreamSourceSuite extends FileStreamSourceTest with SharedSQLContext {
}
test("read from json files") {
- val src = Utils.createTempDir("streaming.src")
- val tmp = Utils.createTempDir("streaming.tmp")
+ val src = Utils.createTempDir(namePrefix = "streaming.src")
+ val tmp = Utils.createTempDir(namePrefix = "streaming.tmp")
val textSource = createFileStreamSource("json", src.getCanonicalPath, Some(valueSchema))
val filtered = textSource.toDF().filter($"value" contains "keep")
@@ -258,8 +261,8 @@ class FileStreamSourceSuite extends FileStreamSourceTest with SharedSQLContext {
}
test("read from json files with inferring schema") {
- val src = Utils.createTempDir("streaming.src")
- val tmp = Utils.createTempDir("streaming.tmp")
+ val src = Utils.createTempDir(namePrefix = "streaming.src")
+ val tmp = Utils.createTempDir(namePrefix = "streaming.tmp")
// Add a file so that we can infer its schema
stringToFile(new File(src, "existing"), "{'c': 'drop1'}\n{'c': 'keep2'}\n{'c': 'keep3'}")
@@ -279,8 +282,8 @@ class FileStreamSourceSuite extends FileStreamSourceTest with SharedSQLContext {
}
test("read from parquet files") {
- val src = Utils.createTempDir("streaming.src")
- val tmp = Utils.createTempDir("streaming.tmp")
+ val src = Utils.createTempDir(namePrefix = "streaming.src")
+ val tmp = Utils.createTempDir(namePrefix = "streaming.tmp")
val fileSource = createFileStreamSource("parquet", src.getCanonicalPath, Some(valueSchema))
val filtered = fileSource.toDF().filter($"value" contains "keep")
@@ -301,7 +304,7 @@ class FileStreamSourceSuite extends FileStreamSourceTest with SharedSQLContext {
}
test("file stream source without schema") {
- val src = Utils.createTempDir("streaming.src")
+ val src = Utils.createTempDir(namePrefix = "streaming.src")
// Only "text" doesn't need a schema
createFileStreamSource("text", src.getCanonicalPath)
@@ -318,8 +321,8 @@ class FileStreamSourceSuite extends FileStreamSourceTest with SharedSQLContext {
}
test("fault tolerance") {
- val src = Utils.createTempDir("streaming.src")
- val tmp = Utils.createTempDir("streaming.tmp")
+ val src = Utils.createTempDir(namePrefix = "streaming.src")
+ val tmp = Utils.createTempDir(namePrefix = "streaming.tmp")
val textSource = createFileStreamSource("text", src.getCanonicalPath)
val filtered = textSource.toDF().filter($"value" contains "keep")
@@ -338,7 +341,6 @@ class FileStreamSourceSuite extends FileStreamSourceTest with SharedSQLContext {
Utils.deleteRecursively(src)
Utils.deleteRecursively(tmp)
}
-
}
class FileStreamSourceStressTestSuite extends FileStreamSourceTest with SharedSQLContext {
@@ -346,8 +348,8 @@ class FileStreamSourceStressTestSuite extends FileStreamSourceTest with SharedSQ
import testImplicits._
test("file source stress test") {
- val src = Utils.createTempDir("streaming.src")
- val tmp = Utils.createTempDir("streaming.tmp")
+ val src = Utils.createTempDir(namePrefix = "streaming.src")
+ val tmp = Utils.createTempDir(namePrefix = "streaming.tmp")
val textSource = createFileStreamSource("text", src.getCanonicalPath)
val ds = textSource.toDS[String]().map(_.toInt + 1)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStressSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStressSuite.scala
index 5a1bfb3a00..5b49a0a86a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStressSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStressSuite.scala
@@ -29,7 +29,7 @@ import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.util.Utils
/**
- * A stress test for streamign queries that read and write files. This test constists of
+ * A stress test for streaming queries that read and write files. This test consists of
* two threads:
* - one that writes out `numRecords` distinct integers to files of random sizes (the total
* number of records is fixed but each file's size / creation time is random).
@@ -43,10 +43,10 @@ class FileStressSuite extends StreamTest with SharedSQLContext {
test("fault tolerance stress test") {
val numRecords = 10000
- val inputDir = Utils.createTempDir("stream.input").getCanonicalPath
- val stagingDir = Utils.createTempDir("stream.staging").getCanonicalPath
- val outputDir = Utils.createTempDir("stream.output").getCanonicalPath
- val checkpoint = Utils.createTempDir("stream.checkpoint").getCanonicalPath
+ val inputDir = Utils.createTempDir(namePrefix = "stream.input").getCanonicalPath
+ val stagingDir = Utils.createTempDir(namePrefix = "stream.staging").getCanonicalPath
+ val outputDir = Utils.createTempDir(namePrefix = "stream.output").getCanonicalPath
+ val checkpoint = Utils.createTempDir(namePrefix = "stream.checkpoint").getCanonicalPath
@volatile
var continue = true
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/MemorySinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/MemorySinkSuite.scala
new file mode 100644
index 0000000000..1f28340545
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/MemorySinkSuite.scala
@@ -0,0 +1,82 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.streaming
+
+import org.apache.spark.sql.{AnalysisException, Row, StreamTest}
+import org.apache.spark.sql.execution.streaming._
+import org.apache.spark.sql.test.SharedSQLContext
+import org.apache.spark.util.Utils
+
+class MemorySinkSuite extends StreamTest with SharedSQLContext {
+ import testImplicits._
+
+ test("registering as a table") {
+ val input = MemoryStream[Int]
+ val query = input.toDF().write
+ .format("memory")
+ .queryName("memStream")
+ .startStream()
+ input.addData(1, 2, 3)
+ query.processAllAvailable()
+
+ checkDataset(
+ sqlContext.table("memStream").as[Int],
+ 1, 2, 3)
+
+ input.addData(4, 5, 6)
+ query.processAllAvailable()
+ checkDataset(
+ sqlContext.table("memStream").as[Int],
+ 1, 2, 3, 4, 5, 6)
+
+ query.stop()
+ }
+
+ test("error when no name is specified") {
+ val error = intercept[AnalysisException] {
+ val input = MemoryStream[Int]
+ val query = input.toDF().write
+ .format("memory")
+ .startStream()
+ }
+
+ assert(error.message contains "queryName must be specified")
+ }
+
+ test("error if attempting to resume specific checkpoint") {
+ val location = Utils.createTempDir(namePrefix = "streaming.checkpoint").getCanonicalPath
+
+ val input = MemoryStream[Int]
+ val query = input.toDF().write
+ .format("memory")
+ .queryName("memStream")
+ .option("checkpointLocation", location)
+ .startStream()
+ input.addData(1, 2, 3)
+ query.processAllAvailable()
+ query.stop()
+
+ intercept[AnalysisException] {
+ input.toDF().write
+ .format("memory")
+ .queryName("memStream")
+ .option("checkpointLocation", location)
+ .startStream()
+ }
+ }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala
index fbb1792596..2bd27c7efd 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala
@@ -17,9 +17,13 @@
package org.apache.spark.sql.streaming
-import org.apache.spark.sql.{Row, StreamTest}
+import org.scalatest.concurrent.Eventually._
+
+import org.apache.spark.sql.{DataFrame, Row, SQLContext, StreamTest}
import org.apache.spark.sql.execution.streaming._
+import org.apache.spark.sql.sources.StreamSourceProvider
import org.apache.spark.sql.test.SharedSQLContext
+import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
class StreamSuite extends StreamTest with SharedSQLContext {
@@ -81,4 +85,69 @@ class StreamSuite extends StreamTest with SharedSQLContext {
AddData(inputData, 1, 2, 3, 4),
CheckAnswer(2, 4))
}
+
+ test("DataFrame reuse") {
+ def assertDF(df: DataFrame) {
+ withTempDir { outputDir =>
+ withTempDir { checkpointDir =>
+ val query = df.write.format("parquet")
+ .option("checkpointLocation", checkpointDir.getAbsolutePath)
+ .startStream(outputDir.getAbsolutePath)
+ try {
+ query.processAllAvailable()
+ val outputDf = sqlContext.read.parquet(outputDir.getAbsolutePath).as[Long]
+ checkDataset[Long](outputDf, (0L to 10L).toArray: _*)
+ } finally {
+ query.stop()
+ }
+ }
+ }
+ }
+
+ val df = sqlContext.read.format(classOf[FakeDefaultSource].getName).stream()
+ assertDF(df)
+ assertDF(df)
+ }
+}
+
+/**
+ * A fake StreamSourceProvider that creates a fake Source that cannot be reused.
+ */
+class FakeDefaultSource extends StreamSourceProvider {
+
+ private val fakeSchema = StructType(StructField("a", IntegerType) :: Nil)
+
+ override def sourceSchema(
+ sqlContext: SQLContext,
+ schema: Option[StructType],
+ providerName: String,
+ parameters: Map[String, String]): (String, StructType) = ("fakeSource", fakeSchema)
+
+ override def createSource(
+ sqlContext: SQLContext,
+ metadataPath: String,
+ schema: Option[StructType],
+ providerName: String,
+ parameters: Map[String, String]): Source = {
+ // Create a fake Source that emits 0 to 10.
+ new Source {
+ private var offset = -1L
+
+ override def schema: StructType = StructType(StructField("a", IntegerType) :: Nil)
+
+ override def getOffset: Option[Offset] = {
+ if (offset >= 10) {
+ None
+ } else {
+ offset += 1
+ Some(LongOffset(offset))
+ }
+ }
+
+ override def getBatch(start: Option[Offset], end: Offset): DataFrame = {
+ val startOffset = start.map(_.asInstanceOf[LongOffset].offset).getOrElse(-1L) + 1
+ sqlContext.range(startOffset, end.asInstanceOf[LongOffset].offset + 1).toDF("a")
+ }
+ }
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala
new file mode 100644
index 0000000000..3af7c01e52
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala
@@ -0,0 +1,130 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.streaming
+
+import org.apache.spark.SparkException
+import org.apache.spark.sql.StreamTest
+import org.apache.spark.sql.execution.streaming._
+import org.apache.spark.sql.expressions.scala.typed
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.test.SharedSQLContext
+
+object FailureSinglton {
+ var firstTime = true
+}
+
+class StreamingAggregationSuite extends StreamTest with SharedSQLContext {
+
+ import testImplicits._
+
+ test("simple count") {
+ val inputData = MemoryStream[Int]
+
+ val aggregated =
+ inputData.toDF()
+ .groupBy($"value")
+ .agg(count("*"))
+ .as[(Int, Long)]
+
+ testStream(aggregated)(
+ AddData(inputData, 3),
+ CheckLastBatch((3, 1)),
+ AddData(inputData, 3, 2),
+ CheckLastBatch((3, 2), (2, 1)),
+ StopStream,
+ StartStream,
+ AddData(inputData, 3, 2, 1),
+ CheckLastBatch((3, 3), (2, 2), (1, 1)),
+ // By default we run in new tuple mode.
+ AddData(inputData, 4, 4, 4, 4),
+ CheckLastBatch((4, 4))
+ )
+ }
+
+ test("multiple keys") {
+ val inputData = MemoryStream[Int]
+
+ val aggregated =
+ inputData.toDF()
+ .groupBy($"value", $"value" + 1)
+ .agg(count("*"))
+ .as[(Int, Int, Long)]
+
+ testStream(aggregated)(
+ AddData(inputData, 1, 2),
+ CheckLastBatch((1, 2, 1), (2, 3, 1)),
+ AddData(inputData, 1, 2),
+ CheckLastBatch((1, 2, 2), (2, 3, 2))
+ )
+ }
+
+ test("multiple aggregations") {
+ val inputData = MemoryStream[Int]
+
+ val aggregated =
+ inputData.toDF()
+ .groupBy($"value")
+ .agg(count("*") as 'count)
+ .groupBy($"value" % 2)
+ .agg(sum($"count"))
+ .as[(Int, Long)]
+
+ testStream(aggregated)(
+ AddData(inputData, 1, 2, 3, 4),
+ CheckLastBatch((0, 2), (1, 2)),
+ AddData(inputData, 1, 3, 5),
+ CheckLastBatch((1, 5))
+ )
+ }
+
+ testQuietly("midbatch failure") {
+ val inputData = MemoryStream[Int]
+ FailureSinglton.firstTime = true
+ val aggregated =
+ inputData.toDS()
+ .map { i =>
+ if (i == 4 && FailureSinglton.firstTime) {
+ FailureSinglton.firstTime = false
+ sys.error("injected failure")
+ }
+
+ i
+ }
+ .groupBy($"value")
+ .agg(count("*"))
+ .as[(Int, Long)]
+
+ testStream(aggregated)(
+ StartStream,
+ AddData(inputData, 1, 2, 3, 4),
+ ExpectFailure[SparkException](),
+ StartStream,
+ CheckLastBatch((1, 1), (2, 1), (3, 1), (4, 1))
+ )
+ }
+
+ test("typed aggregators") {
+ val inputData = MemoryStream[(String, Int)]
+ val aggregated = inputData.toDS().groupByKey(_._1).agg(typed.sumLong(_._2))
+
+ testStream(aggregated)(
+ AddData(inputData, ("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)),
+ CheckLastBatch(("a", 30), ("b", 3), ("c", 1))
+ )
+ }
+}
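
The aggregation tests above depend on per-key counts accumulating across batches, including across StopStream/StartStream and the injected failure. A minimal sketch of that running count, using a plain mutable map as a hypothetical stand-in for the streaming state (the asserted maps mirror the CheckLastBatch expectations of the "simple count" test for the keys each batch updates):

    object RunningCountSketch extends App {
      // Running count per key, carried across batches like the streaming state store.
      val counts = scala.collection.mutable.Map.empty[Int, Long].withDefaultValue(0L)

      def addBatch(batch: Seq[Int]): Map[Int, Long] = {
        batch.foreach(v => counts(v) += 1)
        counts.toMap
      }

      assert(addBatch(Seq(3)) == Map(3 -> 1L))
      assert(addBatch(Seq(3, 2)) == Map(3 -> 2L, 2 -> 1L))
      assert(addBatch(Seq(3, 2, 1)) == Map(3 -> 3L, 2 -> 2L, 1 -> 1L))
      assert(addBatch(Seq(4, 4, 4, 4))(4) == 4L)   // only key 4 is updated in the last batch
    }
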
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
index 80a85a6615..7844d1b296 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
@@ -29,6 +29,7 @@ import org.scalatest.BeforeAndAfterAll
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.analysis.NoSuchTableException
+import org.apache.spark.sql.catalyst.FunctionIdentifier
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.execution.Filter
@@ -132,6 +133,27 @@ private[sql] trait SQLTestUtils
}
/**
+ * Drops functions after calling `f`. A function is represented by (functionName, isTemporary).
+ */
+ protected def withUserDefinedFunction(functions: (String, Boolean)*)(f: => Unit): Unit = {
+ try {
+ f
+ } catch {
+ case cause: Throwable => throw cause
+ } finally {
+ // If the test failed part way, we don't want to mask the failure by failing to remove
+ // temp functions that never got created.
+ try functions.foreach { case (functionName, isTemporary) =>
+ val withTemporary = if (isTemporary) "TEMPORARY" else ""
+ sqlContext.sql(s"DROP $withTemporary FUNCTION IF EXISTS $functionName")
+ assert(
+ !sqlContext.sessionState.catalog.functionExists(FunctionIdentifier(functionName)),
+ s"Function $functionName should have been dropped. But, it still exists.")
+ }
+ }
+ }
+
+ /**
* Drops temporary table `tableName` after calling `f`.
*/
protected def withTempTable(tableNames: String*)(f: => Unit): Unit = {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/util/ContinuousQueryListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/util/ContinuousQueryListenerSuite.scala
index d04783ecac..3498fe83d0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/util/ContinuousQueryListenerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/util/ContinuousQueryListenerSuite.scala
@@ -146,7 +146,6 @@ class ContinuousQueryListenerSuite extends StreamTest with SharedSQLContext with
private def withListenerAdded(listener: ContinuousQueryListener)(body: => Unit): Unit = {
- @volatile var query: StreamExecution = null
try {
failAfter(1 minute) {
sqlContext.streams.addListener(listener)
@@ -212,7 +211,7 @@ class ContinuousQueryListenerSuite extends StreamTest with SharedSQLContext with
case class QueryStatus(
active: Boolean,
- expection: Option[Exception],
+ exception: Option[Exception],
sourceStatuses: Array[SourceStatus],
sinkStatus: SinkStatus)
diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala
index a955314ba3..673a293ce2 100644
--- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala
+++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala
@@ -222,7 +222,7 @@ private[hive] class SparkExecuteStatementOperation(
val useIncrementalCollect =
hiveContext.getConf("spark.sql.thriftServer.incrementalCollect", "false").toBoolean
if (useIncrementalCollect) {
- result.rdd.toLocalIterator
+ result.toLocalIterator.asScala
} else {
result.collect().iterator
}
diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala
index 8e1ebe2937..eb49eabcb1 100644
--- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala
+++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala
@@ -162,7 +162,7 @@ class CliSuite extends SparkFunSuite with BeforeAndAfterAll with Logging {
runCliWithin(3.minute)(
"CREATE TABLE hive_test(key INT, val STRING);"
- -> "OK",
+ -> "",
"SHOW TABLES;"
-> "hive_test",
s"LOAD DATA LOCAL INPATH '$dataFilePath' OVERWRITE INTO TABLE hive_test;"
@@ -172,22 +172,22 @@ class CliSuite extends SparkFunSuite with BeforeAndAfterAll with Logging {
"SELECT COUNT(*) FROM hive_test;"
-> "5",
"DROP TABLE hive_test;"
- -> "OK"
+ -> ""
)
}
test("Single command with -e") {
- runCliWithin(2.minute, Seq("-e", "SHOW DATABASES;"))("" -> "OK")
+ runCliWithin(2.minute, Seq("-e", "SHOW DATABASES;"))("" -> "")
}
test("Single command with --database") {
runCliWithin(2.minute)(
"CREATE DATABASE hive_test_db;"
- -> "OK",
+ -> "",
"USE hive_test_db;"
-> "",
"CREATE TABLE hive_test(key INT, val STRING);"
- -> "OK",
+ -> "",
"SHOW TABLES;"
-> "hive_test"
)
@@ -210,9 +210,9 @@ class CliSuite extends SparkFunSuite with BeforeAndAfterAll with Logging {
"""CREATE TABLE t1(key string, val string)
|ROW FORMAT SERDE 'org.apache.hive.hcatalog.data.JsonSerDe';
""".stripMargin
- -> "OK",
+ -> "",
"CREATE TABLE sourceTable (key INT, val STRING);"
- -> "OK",
+ -> "",
s"LOAD DATA LOCAL INPATH '$dataFilePath' OVERWRITE INTO TABLE sourceTable;"
-> "OK",
"INSERT INTO TABLE t1 SELECT key, val FROM sourceTable;"
@@ -220,9 +220,9 @@ class CliSuite extends SparkFunSuite with BeforeAndAfterAll with Logging {
"SELECT count(key) FROM t1;"
-> "5",
"DROP TABLE t1;"
- -> "OK",
+ -> "",
"DROP TABLE sourceTable;"
- -> "OK"
+ -> ""
)
}
@@ -230,7 +230,7 @@ class CliSuite extends SparkFunSuite with BeforeAndAfterAll with Logging {
runCliWithin(timeout = 2.minute,
errorResponses = Seq("AnalysisException"))(
"select * from nonexistent_table;"
- -> "Error in query: Table not found: nonexistent_table;"
+ -> "Error in query: Table or View not found: nonexistent_table;"
)
}
diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala
index 33af624cfd..a1268b8e94 100644
--- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala
+++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala
@@ -491,46 +491,50 @@ class HiveThriftBinaryServerSuite extends HiveThriftJdbcTest {
test("SPARK-11595 ADD JAR with input path having URL scheme") {
withJdbcStatement { statement =>
- val jarPath = "../hive/src/test/resources/TestUDTF.jar"
- val jarURL = s"file://${System.getProperty("user.dir")}/$jarPath"
+ try {
+ val jarPath = "../hive/src/test/resources/TestUDTF.jar"
+ val jarURL = s"file://${System.getProperty("user.dir")}/$jarPath"
- Seq(
- s"ADD JAR $jarURL",
- s"""CREATE TEMPORARY FUNCTION udtf_count2
- |AS 'org.apache.spark.sql.hive.execution.GenericUDTFCount2'
- """.stripMargin
- ).foreach(statement.execute)
+ Seq(
+ s"ADD JAR $jarURL",
+ s"""CREATE TEMPORARY FUNCTION udtf_count2
+ |AS 'org.apache.spark.sql.hive.execution.GenericUDTFCount2'
+ """.stripMargin
+ ).foreach(statement.execute)
- val rs1 = statement.executeQuery("DESCRIBE FUNCTION udtf_count2")
+ val rs1 = statement.executeQuery("DESCRIBE FUNCTION udtf_count2")
- assert(rs1.next())
- assert(rs1.getString(1) === "Function: udtf_count2")
+ assert(rs1.next())
+ assert(rs1.getString(1) === "Function: udtf_count2")
- assert(rs1.next())
- assertResult("Class: org.apache.spark.sql.hive.execution.GenericUDTFCount2") {
- rs1.getString(1)
- }
+ assert(rs1.next())
+ assertResult("Class: org.apache.spark.sql.hive.execution.GenericUDTFCount2") {
+ rs1.getString(1)
+ }
- assert(rs1.next())
- assert(rs1.getString(1) === "Usage: To be added.")
+ assert(rs1.next())
+ assert(rs1.getString(1) === "Usage: To be added.")
- val dataPath = "../hive/src/test/resources/data/files/kv1.txt"
+ val dataPath = "../hive/src/test/resources/data/files/kv1.txt"
- Seq(
- s"CREATE TABLE test_udtf(key INT, value STRING)",
- s"LOAD DATA LOCAL INPATH '$dataPath' OVERWRITE INTO TABLE test_udtf"
- ).foreach(statement.execute)
+ Seq(
+ s"CREATE TABLE test_udtf(key INT, value STRING)",
+ s"LOAD DATA LOCAL INPATH '$dataPath' OVERWRITE INTO TABLE test_udtf"
+ ).foreach(statement.execute)
- val rs2 = statement.executeQuery(
- "SELECT key, cc FROM test_udtf LATERAL VIEW udtf_count2(value) dd AS cc")
+ val rs2 = statement.executeQuery(
+ "SELECT key, cc FROM test_udtf LATERAL VIEW udtf_count2(value) dd AS cc")
- assert(rs2.next())
- assert(rs2.getInt(1) === 97)
- assert(rs2.getInt(2) === 500)
+ assert(rs2.next())
+ assert(rs2.getInt(1) === 97)
+ assert(rs2.getInt(2) === 500)
- assert(rs2.next())
- assert(rs2.getInt(1) === 97)
- assert(rs2.getInt(2) === 500)
+ assert(rs2.next())
+ assert(rs2.getInt(1) === 97)
+ assert(rs2.getInt(2) === 500)
+ } finally {
+ statement.executeQuery("DROP TEMPORARY FUNCTION udtf_count2")
+ }
}
}
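The rewrite above wraps the test body in try/finally so the temporary function is always dropped, even when an assertion fails, keeping later tests in the shared server session clean. A minimal sketch of that pattern, with a hypothetical withTemporaryFunction helper around a plain java.sql.Statement:

import java.sql.Statement

def withTemporaryFunction(statement: Statement)(body: => Unit): Unit = {
  statement.execute(
    "CREATE TEMPORARY FUNCTION udtf_count2 " +
      "AS 'org.apache.spark.sql.hive.execution.GenericUDTFCount2'")
  try {
    body   // run the assertions against the function
  } finally {
    // always executed, so the function does not leak into other tests
    statement.executeQuery("DROP TEMPORARY FUNCTION udtf_count2")
  }
}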
@@ -565,24 +569,28 @@ class SingleSessionSuite extends HiveThriftJdbcTest {
},
{ statement =>
- val rs1 = statement.executeQuery("SET foo")
+ try {
+ val rs1 = statement.executeQuery("SET foo")
- assert(rs1.next())
- assert(rs1.getString(1) === "foo")
- assert(rs1.getString(2) === "bar")
+ assert(rs1.next())
+ assert(rs1.getString(1) === "foo")
+ assert(rs1.getString(2) === "bar")
- val rs2 = statement.executeQuery("DESCRIBE FUNCTION udtf_count2")
+ val rs2 = statement.executeQuery("DESCRIBE FUNCTION udtf_count2")
- assert(rs2.next())
- assert(rs2.getString(1) === "Function: udtf_count2")
+ assert(rs2.next())
+ assert(rs2.getString(1) === "Function: udtf_count2")
- assert(rs2.next())
- assertResult("Class: org.apache.spark.sql.hive.execution.GenericUDTFCount2") {
- rs2.getString(1)
- }
+ assert(rs2.next())
+ assertResult("Class: org.apache.spark.sql.hive.execution.GenericUDTFCount2") {
+ rs2.getString(1)
+ }
- assert(rs2.next())
- assert(rs2.getString(1) === "Usage: To be added.")
+ assert(rs2.next())
+ assert(rs2.getString(1) === "Usage: To be added.")
+ } finally {
+ statement.executeQuery("DROP TEMPORARY FUNCTION udtf_count2")
+ }
}
)
}
@@ -763,11 +771,15 @@ abstract class HiveThriftServer2Test extends SparkFunSuite with BeforeAndAfterAl
extraEnvironment = Map(
// Disables SPARK_TESTING to exclude log4j.properties in test directories.
"SPARK_TESTING" -> "0",
+ // But set SPARK_SQL_TESTING to make spark-class happy.
+ "SPARK_SQL_TESTING" -> "1",
// Points SPARK_PID_DIR to SPARK_HOME, otherwise only 1 Thrift server instance can be
// started at a time, which is not Jenkins friendly.
"SPARK_PID_DIR" -> pidDir.getCanonicalPath),
redirectStderr = true)
+ logInfo(s"COMMAND: $command")
+ logInfo(s"OUTPUT: $lines")
lines.split("\n").collectFirst {
case line if line.contains(LOG_FILE_MARK) => new File(line.drop(LOG_FILE_MARK.length))
}.getOrElse {
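The server above is launched with SPARK_TESTING=0 (so test-only log4j.properties are not picked up) but SPARK_SQL_TESTING=1 (so spark-class still accepts the test launch), and the two new logInfo calls dump the command and its output when startup goes wrong. A small sketch, using scala.sys.process rather than Spark's own process utilities, of how such an extra-environment map can be applied to a child process:

import scala.sys.process._

// Hypothetical helper: run a command with extra environment variables layered
// on top of the inherited environment, returning its stdout.
def runWithEnv(command: Seq[String], extraEnvironment: Map[String, String]): String =
  Process(command, None, extraEnvironment.toSeq: _*).!!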
diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
index 650797f768..989e68aebe 100644
--- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
+++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
@@ -20,9 +20,10 @@ package org.apache.spark.sql.hive.execution
import java.io.File
import java.util.{Locale, TimeZone}
-import org.apache.spark.sql.catalyst.rules.RuleExecutor
import org.scalatest.BeforeAndAfter
+import org.apache.spark.sql.catalyst.rules.RuleExecutor
+import org.apache.spark.sql.hive.HiveContext
import org.apache.spark.sql.hive.test.TestHive
import org.apache.spark.sql.internal.SQLConf
@@ -38,6 +39,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
private val originalLocale = Locale.getDefault
private val originalColumnBatchSize = TestHive.conf.columnBatchSize
private val originalInMemoryPartitionPruning = TestHive.conf.inMemoryPartitionPruning
+ private val originalConvertMetastoreOrc = TestHive.convertMetastoreOrc
def testCases: Seq[(String, File)] = {
hiveQueryDir.listFiles.map(f => f.getName.stripSuffix(".q") -> f)
@@ -56,6 +58,9 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
TestHive.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, true)
// Use Hive hash expression instead of the native one
TestHive.sessionState.functionRegistry.unregisterFunction("hash")
+ // Ensures that plan generation uses the metastore relation and not OrcRelation.
+ // This is done because SqlBuilder does not work with plans having a logical relation.
+ TestHive.setConf(HiveContext.CONVERT_METASTORE_ORC, false)
RuleExecutor.resetTime()
}
@@ -66,6 +71,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
Locale.setDefault(originalLocale)
TestHive.setConf(SQLConf.COLUMN_BATCH_SIZE, originalColumnBatchSize)
TestHive.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, originalInMemoryPartitionPruning)
+ TestHive.setConf(HiveContext.CONVERT_METASTORE_ORC, originalConvertMetastoreOrc)
TestHive.sessionState.functionRegistry.restore()
// For debugging dump some statistics about how much time was spent in various optimizer rules.
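The before/after hooks above snapshot convertMetastoreOrc, force it off for the whole suite, and restore the original value afterwards. A minimal sketch of that snapshot-and-restore shape, with a hypothetical mutable setting standing in for TestHive's conf:

object SettingHolder { var convertMetastoreOrc: Boolean = true }

val original = SettingHolder.convertMetastoreOrc   // captured in beforeAll()
SettingHolder.convertMetastoreOrc = false          // overridden for the suite
try {
  // ... run the compatibility queries ...
} finally {
  SettingHolder.convertMetastoreOrc = original     // restored in afterAll()
}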
@@ -291,7 +297,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
"compute_stats_empty_table",
"compute_stats_long",
"create_view_translate",
- "show_create_table_serde",
"show_tblproperties",
// Odd changes to output
@@ -344,7 +349,109 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
// These tests check the VIEW table definition, but Spark handles CREATE VIEW itself and
// generates different View Expanded Text.
"alter_view_as_select",
- "show_create_table_view"
+
+ // We don't support show create table commands in general
+ "show_create_table_alter",
+ "show_create_table_db_table",
+ "show_create_table_delimited",
+ "show_create_table_does_not_exist",
+ "show_create_table_index",
+ "show_create_table_partitioned",
+ "show_create_table_serde",
+ "show_create_table_view",
+
+ // These tests try to change how a table is bucketed, which we don't support
+ "alter4",
+ "sort_merge_join_desc_5",
+ "sort_merge_join_desc_6",
+ "sort_merge_join_desc_7",
+
+ // These tests try to create a table with bucketed columns, which we don't support
+ "auto_join32",
+ "auto_join_filters",
+ "auto_smb_mapjoin_14",
+ "ct_case_insensitive",
+ "explain_rearrange",
+ "groupby_sort_10",
+ "groupby_sort_2",
+ "groupby_sort_3",
+ "groupby_sort_4",
+ "groupby_sort_5",
+ "groupby_sort_7",
+ "groupby_sort_8",
+ "groupby_sort_9",
+ "groupby_sort_test_1",
+ "inputddl4",
+ "join_filters",
+ "join_nulls",
+ "join_nullsafe",
+ "load_dyn_part2",
+ "orc_empty_files",
+ "reduce_deduplicate",
+ "smb_mapjoin9",
+ "smb_mapjoin_1",
+ "smb_mapjoin_10",
+ "smb_mapjoin_13",
+ "smb_mapjoin_14",
+ "smb_mapjoin_15",
+ "smb_mapjoin_16",
+ "smb_mapjoin_17",
+ "smb_mapjoin_2",
+ "smb_mapjoin_21",
+ "smb_mapjoin_25",
+ "smb_mapjoin_3",
+ "smb_mapjoin_4",
+ "smb_mapjoin_5",
+ "smb_mapjoin_6",
+ "smb_mapjoin_7",
+ "smb_mapjoin_8",
+ "sort_merge_join_desc_1",
+ "sort_merge_join_desc_2",
+ "sort_merge_join_desc_3",
+ "sort_merge_join_desc_4",
+
+ // These tests try to create a table with skewed columns, which we don't support
+ "create_skewed_table1",
+ "skewjoinopt13",
+ "skewjoinopt18",
+ "skewjoinopt9",
+
+ // This test tries to create a table like with TBLPROPERTIES clause, which we don't support.
+ "create_like_tbl_props",
+
+ // Index commands are not supported
+ "drop_index",
+ "drop_index_removes_partition_dirs",
+ "alter_index",
+ "auto_sortmerge_join_1",
+ "auto_sortmerge_join_10",
+ "auto_sortmerge_join_11",
+ "auto_sortmerge_join_12",
+ "auto_sortmerge_join_13",
+ "auto_sortmerge_join_14",
+ "auto_sortmerge_join_15",
+ "auto_sortmerge_join_16",
+ "auto_sortmerge_join_2",
+ "auto_sortmerge_join_3",
+ "auto_sortmerge_join_4",
+ "auto_sortmerge_join_5",
+ "auto_sortmerge_join_6",
+ "auto_sortmerge_join_7",
+ "auto_sortmerge_join_8",
+ "auto_sortmerge_join_9",
+
+ // Macro commands are not supported
+ "macro",
+
+ // Create partitioned view is not supported
+ "create_like_view",
+ "describe_formatted_view_partitioned",
+
+ // This uses CONCATENATE, which we don't support
+ "alter_merge_2",
+
+ // TOUCH is not supported
+ "touch"
)
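All of the names added above come out of the whitelist removals that follow, so the net effect is that the suite skips these cases instead of running them against Hive's golden answers. A minimal sketch, with hypothetical names, of how a file-based suite can gate its cases on a blacklist and a whitelist, which appears to be the mechanism these two Seqs feed:

val blackList = Seq("touch", "macro", "alter_merge_2")
val whiteList = Seq("alter2", "alter3", "sort")

// Skip anything blacklisted, run only what is whitelisted.
def shouldRun(testName: String): Boolean =
  !blackList.contains(testName) && whiteList.contains(testName)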
/**
@@ -359,10 +466,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
"alias_casted_column",
"alter2",
"alter3",
- "alter4",
"alter5",
- "alter_index",
- "alter_merge_2",
"alter_partition_format_loc",
"alter_partition_with_whitelist",
"alter_rename_partition",
@@ -400,33 +504,14 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
"auto_join3",
"auto_join30",
"auto_join31",
- "auto_join32",
"auto_join4",
"auto_join5",
"auto_join6",
"auto_join7",
"auto_join8",
"auto_join9",
- "auto_join_filters",
"auto_join_nulls",
"auto_join_reordering_values",
- "auto_smb_mapjoin_14",
- "auto_sortmerge_join_1",
- "auto_sortmerge_join_10",
- "auto_sortmerge_join_11",
- "auto_sortmerge_join_12",
- "auto_sortmerge_join_13",
- "auto_sortmerge_join_14",
- "auto_sortmerge_join_15",
- "auto_sortmerge_join_16",
- "auto_sortmerge_join_2",
- "auto_sortmerge_join_3",
- "auto_sortmerge_join_4",
- "auto_sortmerge_join_5",
- "auto_sortmerge_join_6",
- "auto_sortmerge_join_7",
- "auto_sortmerge_join_8",
- "auto_sortmerge_join_9",
"binary_constant",
"binarysortable_1",
"cast1",
@@ -455,16 +540,12 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
"count",
"cp_mj_rc",
"create_insert_outputformat",
- "create_like_tbl_props",
- "create_like_view",
"create_nested_type",
- "create_skewed_table1",
"create_struct_table",
"create_view_translate",
"cross_join",
"cross_product_check_1",
"cross_product_check_2",
- "ct_case_insensitive",
"database_drop",
"database_location",
"database_properties",
@@ -481,15 +562,12 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
"default_partition_name",
"delimiter",
"desc_non_existent_tbl",
- "describe_formatted_view_partitioned",
"diff_part_input_formats",
"disable_file_format_check",
"disallow_incompatible_type_change_off",
"distinct_stats",
"drop_database_removes_partition_dirs",
"drop_function",
- "drop_index",
- "drop_index_removes_partition_dirs",
"drop_multi_partitions",
"drop_partitions_filter",
"drop_partitions_filter2",
@@ -503,7 +581,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
"escape_distributeby1",
"escape_orderby1",
"escape_sortby1",
- "explain_rearrange",
"fileformat_mix",
"fileformat_sequencefile",
"fileformat_text",
@@ -558,16 +635,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
"groupby_neg_float",
"groupby_ppd",
"groupby_ppr",
- "groupby_sort_10",
- "groupby_sort_2",
- "groupby_sort_3",
- "groupby_sort_4",
- "groupby_sort_5",
"groupby_sort_6",
- "groupby_sort_7",
- "groupby_sort_8",
- "groupby_sort_9",
- "groupby_sort_test_1",
"having",
"implicit_cast1",
"index_serde",
@@ -622,7 +690,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
"inputddl1",
"inputddl2",
"inputddl3",
- "inputddl4",
"inputddl6",
"inputddl7",
"inputddl8",
@@ -678,11 +745,8 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
"join_array",
"join_casesensitive",
"join_empty",
- "join_filters",
"join_hive_626",
"join_map_ppr",
- "join_nulls",
- "join_nullsafe",
"join_rc",
"join_reorder2",
"join_reorder3",
@@ -706,7 +770,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
"load_dyn_part13",
"load_dyn_part14",
"load_dyn_part14_win",
- "load_dyn_part2",
"load_dyn_part3",
"load_dyn_part4",
"load_dyn_part5",
@@ -717,7 +780,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
"load_file_with_space_in_the_name",
"loadpart1",
"louter_join_ppr",
- "macro",
"mapjoin_distinct",
"mapjoin_filter_on_outerjoin",
"mapjoin_mapjoin",
@@ -760,7 +822,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
"nullscript",
"optional_outer",
"orc_dictionary_threshold",
- "orc_empty_files",
"order",
"order2",
"outer_join_ppr",
@@ -816,7 +877,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
"rcfile_null_value",
"rcfile_toleratecorruptions",
"rcfile_union",
- "reduce_deduplicate",
"reduce_deduplicate_exclude_gby",
"reduce_deduplicate_exclude_join",
"reduce_deduplicate_extended",
@@ -833,45 +893,11 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
"serde_reported_schema",
"set_variable_sub",
"show_columns",
- "show_create_table_alter",
- "show_create_table_db_table",
- "show_create_table_delimited",
- "show_create_table_does_not_exist",
- "show_create_table_index",
- "show_create_table_partitioned",
- "show_create_table_serde",
"show_describe_func_quotes",
"show_functions",
"show_partitions",
"show_tblproperties",
- "skewjoinopt13",
- "skewjoinopt18",
- "skewjoinopt9",
- "smb_mapjoin9",
- "smb_mapjoin_1",
- "smb_mapjoin_10",
- "smb_mapjoin_13",
- "smb_mapjoin_14",
- "smb_mapjoin_15",
- "smb_mapjoin_16",
- "smb_mapjoin_17",
- "smb_mapjoin_2",
- "smb_mapjoin_21",
- "smb_mapjoin_25",
- "smb_mapjoin_3",
- "smb_mapjoin_4",
- "smb_mapjoin_5",
- "smb_mapjoin_6",
- "smb_mapjoin_7",
- "smb_mapjoin_8",
"sort",
- "sort_merge_join_desc_1",
- "sort_merge_join_desc_2",
- "sort_merge_join_desc_3",
- "sort_merge_join_desc_4",
- "sort_merge_join_desc_5",
- "sort_merge_join_desc_6",
- "sort_merge_join_desc_7",
"stats0",
"stats_aggregator_error_1",
"stats_empty_partition",
@@ -882,7 +908,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
"timestamp_comparison",
"timestamp_lazy",
"timestamp_null",
- "touch",
"transform_ppr1",
"transform_ppr2",
"truncate_table",
diff --git a/sql/hive/pom.xml b/sql/hive/pom.xml
index 22bad93e6d..61504becf1 100644
--- a/sql/hive/pom.xml
+++ b/sql/hive/pom.xml
@@ -225,49 +225,6 @@
<argLine>-da -Xmx3g -XX:MaxPermSize=${MaxPermGen} -XX:ReservedCodeCacheSize=512m</argLine>
</configuration>
</plugin>
- <plugin>
- <groupId>org.codehaus.mojo</groupId>
- <artifactId>build-helper-maven-plugin</artifactId>
- <executions>
- <execution>
- <id>add-default-sources</id>
- <phase>generate-sources</phase>
- <goals>
- <goal>add-source</goal>
- </goals>
- <configuration>
- <sources>
- <source>v${hive.version.short}/src/main/scala</source>
- <source>${project.build.directory}/generated-sources/antlr</source>
- </sources>
- </configuration>
- </execution>
- </executions>
- </plugin>
-
- <!-- Deploy datanucleus jars to the spark/lib_managed/jars directory -->
- <plugin>
- <groupId>org.apache.maven.plugins</groupId>
- <artifactId>maven-dependency-plugin</artifactId>
- <executions>
- <execution>
- <id>copy-dependencies</id>
- <phase>package</phase>
- <goals>
- <goal>copy-dependencies</goal>
- </goals>
- <configuration>
- <!-- basedir is spark/sql/hive/ -->
- <outputDirectory>${basedir}/../../lib_managed/jars</outputDirectory>
- <overWriteReleases>false</overWriteReleases>
- <overWriteSnapshots>false</overWriteSnapshots>
- <overWriteIfNewer>true</overWriteIfNewer>
- <includeGroupIds>org.datanucleus</includeGroupIds>
- </configuration>
- </execution>
- </executions>
- </plugin>
-
</plugins>
</build>
</project>
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
index ca3ce43591..505e5c0bb6 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
@@ -42,6 +42,7 @@ import org.apache.hadoop.util.VersionInfo
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.api.java.JavaSparkContext
import org.apache.spark.internal.Logging
+import org.apache.spark.internal.config.ConfigEntry
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis._
@@ -54,8 +55,7 @@ import org.apache.spark.sql.execution.ui.SQLListener
import org.apache.spark.sql.hive.client._
import org.apache.spark.sql.hive.execution.{DescribeHiveTableCommand, HiveNativeCommand}
import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.internal.SQLConf.SQLConfEntry
-import org.apache.spark.sql.internal.SQLConf.SQLConfEntry._
+import org.apache.spark.sql.internal.SQLConf._
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.Utils
@@ -86,7 +86,7 @@ class HiveContext private[hive](
@transient private[hive] val executionHive: HiveClientImpl,
@transient private[hive] val metadataHive: HiveClient,
isRootContext: Boolean,
- @transient private[sql] val hiveCatalog: HiveCatalog)
+ @transient private[sql] val hiveCatalog: HiveExternalCatalog)
extends SQLContext(sc, cacheManager, listener, isRootContext, hiveCatalog) with Logging {
self =>
@@ -98,7 +98,7 @@ class HiveContext private[hive](
execHive,
metaHive,
true,
- new HiveCatalog(metaHive))
+ new HiveExternalCatalog(metaHive))
}
def this(sc: SparkContext) = {
@@ -155,6 +155,13 @@ class HiveContext private[hive](
getConf(CONVERT_METASTORE_PARQUET_WITH_SCHEMA_MERGING)
/**
+ * When true, enables an experimental feature where metastore tables that use the Orc SerDe
+ * are automatically converted to use the Spark SQL ORC table scan, instead of the Hive
+ * SerDe.
+ */
+ protected[sql] def convertMetastoreOrc: Boolean = getConf(CONVERT_METASTORE_ORC)
+
+ /**
* When true, a table created by a Hive CTAS statement (no USING clause) will be
* converted to a data source table, using the data source set by spark.sql.sources.default.
* The table in CTAS statement will be converted when it meets any of the following conditions:
@@ -311,7 +318,7 @@ class HiveContext private[hive](
hiveconf.set(key, value)
}
- override private[sql] def setConf[T](entry: SQLConfEntry[T], value: T): Unit = {
+ override private[sql] def setConf[T](entry: ConfigEntry[T], value: T): Unit = {
setConf(entry.key, entry.stringConverter(value))
}
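Two small things above: the setConf override now takes the core ConfigEntry type instead of the removed SQLConfEntry, and a new convertMetastoreOrc flag is exposed. At the string level the flag can be toggled like any other SQL conf; a usage sketch, assuming an existing HiveContext named hiveContext:

// Falls back to the Hive ORC SerDe read path instead of Spark's ORC data source.
hiveContext.setConf("spark.sql.hive.convertMetastoreOrc", "false")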
@@ -406,19 +413,19 @@ private[hive] object HiveContext extends Logging {
/** The version of hive used internally by Spark SQL. */
val hiveExecutionVersion: String = "1.2.1"
- val HIVE_METASTORE_VERSION = stringConf("spark.sql.hive.metastore.version",
- defaultValue = Some(hiveExecutionVersion),
- doc = "Version of the Hive metastore. Available options are " +
+ val HIVE_METASTORE_VERSION = SQLConfigBuilder("spark.sql.hive.metastore.version")
+ .doc("Version of the Hive metastore. Available options are " +
s"<code>0.12.0</code> through <code>$hiveExecutionVersion</code>.")
+ .stringConf
+ .createWithDefault(hiveExecutionVersion)
- val HIVE_EXECUTION_VERSION = stringConf(
- key = "spark.sql.hive.version",
- defaultValue = Some(hiveExecutionVersion),
- doc = "Version of Hive used internally by Spark SQL.")
+ val HIVE_EXECUTION_VERSION = SQLConfigBuilder("spark.sql.hive.version")
+ .doc("Version of Hive used internally by Spark SQL.")
+ .stringConf
+ .createWithDefault(hiveExecutionVersion)
- val HIVE_METASTORE_JARS = stringConf("spark.sql.hive.metastore.jars",
- defaultValue = Some("builtin"),
- doc = s"""
+ val HIVE_METASTORE_JARS = SQLConfigBuilder("spark.sql.hive.metastore.jars")
+ .doc(s"""
| Location of the jars that should be used to instantiate the HiveMetastoreClient.
| This property can be one of three options: "
| 1. "builtin"
@@ -429,44 +436,61 @@ private[hive] object HiveContext extends Logging {
| 2. "maven"
| Use Hive jars of specified version downloaded from Maven repositories.
| 3. A classpath in the standard format for both Hive and Hadoop.
- """.stripMargin)
- val CONVERT_METASTORE_PARQUET = booleanConf("spark.sql.hive.convertMetastoreParquet",
- defaultValue = Some(true),
- doc = "When set to false, Spark SQL will use the Hive SerDe for parquet tables instead of " +
- "the built in support.")
+ """.stripMargin)
+ .stringConf
+ .createWithDefault("builtin")
- val CONVERT_METASTORE_PARQUET_WITH_SCHEMA_MERGING = booleanConf(
- "spark.sql.hive.convertMetastoreParquet.mergeSchema",
- defaultValue = Some(false),
- doc = "When true, also tries to merge possibly different but compatible Parquet schemas in " +
- "different Parquet data files. This configuration is only effective " +
- "when \"spark.sql.hive.convertMetastoreParquet\" is true.")
-
- val CONVERT_CTAS = booleanConf("spark.sql.hive.convertCTAS",
- defaultValue = Some(false),
- doc = "When true, a table created by a Hive CTAS statement (no USING clause) will be " +
+ val CONVERT_METASTORE_PARQUET = SQLConfigBuilder("spark.sql.hive.convertMetastoreParquet")
+ .doc("When set to false, Spark SQL will use the Hive SerDe for parquet tables instead of " +
+ "the built in support.")
+ .booleanConf
+ .createWithDefault(true)
+
+ val CONVERT_METASTORE_PARQUET_WITH_SCHEMA_MERGING =
+ SQLConfigBuilder("spark.sql.hive.convertMetastoreParquet.mergeSchema")
+ .doc("When true, also tries to merge possibly different but compatible Parquet schemas in " +
+ "different Parquet data files. This configuration is only effective " +
+ "when \"spark.sql.hive.convertMetastoreParquet\" is true.")
+ .booleanConf
+ .createWithDefault(false)
+
+ val CONVERT_CTAS = SQLConfigBuilder("spark.sql.hive.convertCTAS")
+ .doc("When true, a table created by a Hive CTAS statement (no USING clause) will be " +
"converted to a data source table, using the data source set by spark.sql.sources.default.")
+ .booleanConf
+ .createWithDefault(false)
- val HIVE_METASTORE_SHARED_PREFIXES = stringSeqConf("spark.sql.hive.metastore.sharedPrefixes",
- defaultValue = Some(jdbcPrefixes),
- doc = "A comma separated list of class prefixes that should be loaded using the classloader " +
+ val CONVERT_METASTORE_ORC = SQLConfigBuilder("spark.sql.hive.convertMetastoreOrc")
+ .doc("When set to false, Spark SQL will use the Hive SerDe for ORC tables instead of " +
+ "the built in support.")
+ .booleanConf
+ .createWithDefault(true)
+
+ val HIVE_METASTORE_SHARED_PREFIXES = SQLConfigBuilder("spark.sql.hive.metastore.sharedPrefixes")
+ .doc("A comma separated list of class prefixes that should be loaded using the classloader " +
"that is shared between Spark SQL and a specific version of Hive. An example of classes " +
"that should be shared is JDBC drivers that are needed to talk to the metastore. Other " +
"classes that need to be shared are those that interact with classes that are already " +
"shared. For example, custom appenders that are used by log4j.")
+ .stringConf
+ .toSequence
+ .createWithDefault(jdbcPrefixes)
private def jdbcPrefixes = Seq(
"com.mysql.jdbc", "org.postgresql", "com.microsoft.sqlserver", "oracle.jdbc")
- val HIVE_METASTORE_BARRIER_PREFIXES = stringSeqConf("spark.sql.hive.metastore.barrierPrefixes",
- defaultValue = Some(Seq()),
- doc = "A comma separated list of class prefixes that should explicitly be reloaded for each " +
+ val HIVE_METASTORE_BARRIER_PREFIXES = SQLConfigBuilder("spark.sql.hive.metastore.barrierPrefixes")
+ .doc("A comma separated list of class prefixes that should explicitly be reloaded for each " +
"version of Hive that Spark SQL is communicating with. For example, Hive UDFs that are " +
"declared in a prefix that typically would be shared (i.e. <code>org.apache.spark.*</code>).")
-
- val HIVE_THRIFT_SERVER_ASYNC = booleanConf("spark.sql.hive.thriftServer.async",
- defaultValue = Some(true),
- doc = "When set to true, Hive Thrift server executes SQL queries in an asynchronous way.")
+ .stringConf
+ .toSequence
+ .createWithDefault(Nil)
+
+ val HIVE_THRIFT_SERVER_ASYNC = SQLConfigBuilder("spark.sql.hive.thriftServer.async")
+ .doc("When set to true, Hive Thrift server executes SQL queries in an asynchronous way.")
+ .booleanConf
+ .createWithDefault(true)
/**
* The version of the hive client that will be used to communicate with the metastore. Note that
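The declarations above all migrate from the old stringConf/booleanConf helpers to the fluent SQLConfigBuilder style: name, doc, type, then default. A sketch of one entry in the new style; the key and wording are illustrative, the method chain mirrors the entries in this hunk.

val SOME_ILLUSTRATIVE_FLAG = SQLConfigBuilder("spark.sql.hive.someIllustrativeFlag")
  .doc("Hypothetical flag used only in this sketch; it is not a real Spark setting.")
  .booleanConf
  .createWithDefault(true)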
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala
index 0722fb02a8..f627384253 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveCatalog.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala
@@ -25,7 +25,6 @@ import org.apache.thrift.TException
import org.apache.spark.internal.Logging
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.TableIdentifier
-import org.apache.spark.sql.catalyst.analysis.NoSuchItemException
import org.apache.spark.sql.catalyst.catalog._
import org.apache.spark.sql.hive.client.HiveClient
@@ -34,7 +33,7 @@ import org.apache.spark.sql.hive.client.HiveClient
* A persistent implementation of the system catalog using Hive.
* All public methods must be synchronized for thread-safety.
*/
-private[spark] class HiveCatalog(client: HiveClient) extends ExternalCatalog with Logging {
+private[spark] class HiveExternalCatalog(client: HiveClient) extends ExternalCatalog with Logging {
import ExternalCatalog._
// Exceptions thrown by the hive client that we would like to wrap
@@ -66,18 +65,16 @@ private[spark] class HiveCatalog(client: HiveClient) extends ExternalCatalog wit
try {
body
} catch {
- case e: NoSuchItemException =>
- throw new AnalysisException(e.getMessage)
case NonFatal(e) if isClientException(e) =>
throw new AnalysisException(e.getClass.getCanonicalName + ": " + e.getMessage)
}
}
private def requireDbMatches(db: String, table: CatalogTable): Unit = {
- if (table.name.database != Some(db)) {
+ if (table.identifier.database != Some(db)) {
throw new AnalysisException(
- s"Provided database $db does not much the one specified in the " +
- s"table definition (${table.name.database.getOrElse("n/a")})")
+ s"Provided database $db does not match the one specified in the " +
+ s"table definition (${table.identifier.database.getOrElse("n/a")})")
}
}
@@ -160,7 +157,8 @@ private[spark] class HiveCatalog(client: HiveClient) extends ExternalCatalog wit
}
override def renameTable(db: String, oldName: String, newName: String): Unit = withClient {
- val newTable = client.getTable(db, oldName).copy(name = TableIdentifier(newName, Some(db)))
+ val newTable = client.getTable(db, oldName)
+ .copy(identifier = TableIdentifier(newName, Some(db)))
client.alterTable(oldName, newTable)
}
@@ -173,7 +171,7 @@ private[spark] class HiveCatalog(client: HiveClient) extends ExternalCatalog wit
*/
override def alterTable(db: String, tableDefinition: CatalogTable): Unit = withClient {
requireDbMatches(db, tableDefinition)
- requireTableExists(db, tableDefinition.name.table)
+ requireTableExists(db, tableDefinition.identifier.table)
client.alterTable(tableDefinition)
}
@@ -181,6 +179,10 @@ private[spark] class HiveCatalog(client: HiveClient) extends ExternalCatalog wit
client.getTable(db, table)
}
+ override def getTableOption(db: String, table: String): Option[CatalogTable] = withClient {
+ client.getTableOption(db, table)
+ }
+
override def tableExists(db: String, table: String): Boolean = withClient {
client.getTableOption(db, table).isDefined
}
@@ -214,26 +216,7 @@ private[spark] class HiveCatalog(client: HiveClient) extends ExternalCatalog wit
parts: Seq[TablePartitionSpec],
ignoreIfNotExists: Boolean): Unit = withClient {
requireTableExists(db, table)
- // Note: Unfortunately Hive does not currently support `ignoreIfNotExists` so we
- // need to implement it here ourselves. This is currently somewhat expensive because
- // we make multiple synchronous calls to Hive for each partition we want to drop.
- val partsToDrop =
- if (ignoreIfNotExists) {
- parts.filter { spec =>
- try {
- getPartition(db, table, spec)
- true
- } catch {
- // Filter out the partitions that do not actually exist
- case _: AnalysisException => false
- }
- }
- } else {
- parts
- }
- if (partsToDrop.nonEmpty) {
- client.dropPartitions(db, table, partsToDrop)
- }
+ client.dropPartitions(db, table, parts, ignoreIfNotExists)
}
override def renamePartitions(
@@ -271,7 +254,12 @@ private[spark] class HiveCatalog(client: HiveClient) extends ExternalCatalog wit
override def createFunction(
db: String,
funcDefinition: CatalogFunction): Unit = withClient {
- client.createFunction(db, funcDefinition)
+ // Hive's metastore is case insensitive. However, Hive's createFunction does
+ // not normalize the function name (unlike the getFunction part). So,
+ // we are normalizing the function name.
+ val functionName = funcDefinition.identifier.funcName.toLowerCase
+ val functionIdentifier = funcDefinition.identifier.copy(funcName = functionName)
+ client.createFunction(db, funcDefinition.copy(identifier = functionIdentifier))
}
override def dropFunction(db: String, name: String): Unit = withClient {
@@ -282,14 +270,14 @@ private[spark] class HiveCatalog(client: HiveClient) extends ExternalCatalog wit
client.renameFunction(db, oldName, newName)
}
- override def alterFunction(db: String, funcDefinition: CatalogFunction): Unit = withClient {
- client.alterFunction(db, funcDefinition)
- }
-
override def getFunction(db: String, funcName: String): CatalogFunction = withClient {
client.getFunction(db, funcName)
}
+ override def functionExists(db: String, funcName: String): Boolean = withClient {
+ client.functionExists(db, funcName)
+ }
+
override def listFunctions(db: String, pattern: String): Seq[String] = withClient {
client.listFunctions(db, pattern)
}
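createFunction above lower-cases the function name before handing it to the metastore, because Hive's metastore is case insensitive but createFunction (unlike getFunction) does not normalize the name. A minimal sketch of that normalization, with a stand-in case class instead of Spark's identifier type:

case class FunctionIdentifier(funcName: String, database: Option[String] = None)

def normalized(id: FunctionIdentifier): FunctionIdentifier =
  id.copy(funcName = id.funcName.toLowerCase)

// Both spellings end up as the same metastore entry.
assert(normalized(FunctionIdentifier("Udtf_Count2")) == normalized(FunctionIdentifier("udtf_count2")))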
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala
index 589862c7c0..585befe378 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala
@@ -450,9 +450,7 @@ private[hive] trait HiveInspectors {
if (o != null) {
val array = o.asInstanceOf[ArrayData]
val values = new java.util.ArrayList[Any](array.numElements())
- array.foreach(elementType, (_, e) => {
- values.add(wrapper(e))
- })
+ array.foreach(elementType, (_, e) => values.add(wrapper(e)))
values
} else {
null
@@ -468,9 +466,8 @@ private[hive] trait HiveInspectors {
if (o != null) {
val map = o.asInstanceOf[MapData]
val jmap = new java.util.HashMap[Any, Any](map.numElements())
- map.foreach(mt.keyType, mt.valueType, (k, v) => {
- jmap.put(keyWrapper(k), valueWrapper(v))
- })
+ map.foreach(mt.keyType, mt.valueType, (k, v) =>
+ jmap.put(keyWrapper(k), valueWrapper(v)))
jmap
} else {
null
@@ -587,9 +584,9 @@ private[hive] trait HiveInspectors {
case x: ListObjectInspector =>
val list = new java.util.ArrayList[Object]
val tpe = dataType.asInstanceOf[ArrayType].elementType
- a.asInstanceOf[ArrayData].foreach(tpe, (_, e) => {
+ a.asInstanceOf[ArrayData].foreach(tpe, (_, e) =>
list.add(wrap(e, x.getListElementObjectInspector, tpe))
- })
+ )
list
case x: MapObjectInspector =>
val keyType = dataType.asInstanceOf[MapType].keyType
@@ -599,10 +596,10 @@ private[hive] trait HiveInspectors {
// Some UDFs seem to assume we pass in a HashMap.
val hashMap = new java.util.HashMap[Any, Any](map.numElements())
- map.foreach(keyType, valueType, (k, v) => {
+ map.foreach(keyType, valueType, (k, v) =>
hashMap.put(wrap(k, x.getMapKeyObjectInspector, keyType),
wrap(v, x.getMapValueObjectInspector, valueType))
- })
+ )
hashMap
}
@@ -704,9 +701,8 @@ private[hive] trait HiveInspectors {
ObjectInspectorFactory.getStandardConstantListObjectInspector(listObjectInspector, null)
} else {
val list = new java.util.ArrayList[Object]()
- value.asInstanceOf[ArrayData].foreach(dt, (_, e) => {
- list.add(wrap(e, listObjectInspector, dt))
- })
+ value.asInstanceOf[ArrayData].foreach(dt, (_, e) =>
+ list.add(wrap(e, listObjectInspector, dt)))
ObjectInspectorFactory.getStandardConstantListObjectInspector(listObjectInspector, list)
}
case Literal(value, MapType(keyType, valueType, _)) =>
@@ -718,9 +714,8 @@ private[hive] trait HiveInspectors {
val map = value.asInstanceOf[MapData]
val jmap = new java.util.HashMap[Any, Any](map.numElements())
- map.foreach(keyType, valueType, (k, v) => {
- jmap.put(wrap(k, keyOI, keyType), wrap(v, valueOI, valueType))
- })
+ map.foreach(keyType, valueType, (k, v) =>
+ jmap.put(wrap(k, keyOI, keyType), wrap(v, valueOI, valueType)))
ObjectInspectorFactory.getStandardConstantMapObjectInspector(keyOI, valueOI, jmap)
}
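The HiveInspectors hunks above are purely stylistic: the braces around single-expression lambdas passed to ArrayData.foreach and MapData.foreach are dropped, with no behaviour change. The same shape on plain Scala collections, copying elements into a Java list through a wrapper function:

import java.util.{ArrayList => JArrayList}

def wrapAll[A, B](values: Seq[A])(wrapper: A => B): JArrayList[B] = {
  val out = new JArrayList[B](values.size)
  values.foreach(e => out.add(wrapper(e)))   // single-expression lambda, no braces needed
  out
}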
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
index c7066d7363..ccc8345d73 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
@@ -25,7 +25,7 @@ import com.google.common.cache.{CacheBuilder, CacheLoader, LoadingCache}
import org.apache.hadoop.fs.{FileStatus, Path}
import org.apache.hadoop.hive.common.StatsSetupConst
import org.apache.hadoop.hive.conf.HiveConf
-import org.apache.hadoop.hive.metastore.{TableType => HiveTableType, Warehouse}
+import org.apache.hadoop.hive.metastore.{TableType => HiveTableType}
import org.apache.hadoop.hive.metastore.api.FieldSchema
import org.apache.hadoop.hive.ql.metadata.{Table => HiveTable, _}
import org.apache.hadoop.hive.ql.plan.TableDesc
@@ -40,12 +40,13 @@ import org.apache.spark.sql.catalyst.parser.DataTypeParser
import org.apache.spark.sql.catalyst.plans.logical
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._
-import org.apache.spark.sql.execution.{datasources, FileRelation}
+import org.apache.spark.sql.execution.FileRelation
import org.apache.spark.sql.execution.datasources._
-import org.apache.spark.sql.execution.datasources.parquet.{DefaultSource, ParquetRelation}
+import org.apache.spark.sql.execution.datasources.parquet.{DefaultSource => ParquetDefaultSource, ParquetRelation}
import org.apache.spark.sql.hive.client._
import org.apache.spark.sql.hive.execution.HiveNativeCommand
-import org.apache.spark.sql.sources.{HadoopFsRelation, HDFSFileCatalog}
+import org.apache.spark.sql.hive.orc.{DefaultSource => OrcDefaultSource}
+import org.apache.spark.sql.sources.{FileFormat, HadoopFsRelation, HDFSFileCatalog}
import org.apache.spark.sql.types._
private[hive] case class HiveSerDe(
@@ -85,7 +86,18 @@ private[hive] object HiveSerDe {
HiveSerDe(
inputFormat = Option("org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat"),
outputFormat = Option("org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat"),
- serde = Option("org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe")))
+ serde = Option("org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe")),
+
+ "textfile" ->
+ HiveSerDe(
+ inputFormat = Option("org.apache.hadoop.mapred.TextInputFormat"),
+ outputFormat = Option("org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat")),
+
+ "avro" ->
+ HiveSerDe(
+ inputFormat = Option("org.apache.hadoop.hive.ql.io.avro.AvroContainerInputFormat"),
+ outputFormat = Option("org.apache.hadoop.hive.ql.io.avro.AvroContainerOutputFormat"),
+ serde = Option("org.apache.hadoop.hive.serde2.avro.AvroSerDe")))
val key = source.toLowerCase match {
case s if s.startsWith("org.apache.spark.sql.parquet") => "parquet"
@@ -102,7 +114,7 @@ private[hive] object HiveSerDe {
* Legacy catalog for interacting with the Hive metastore.
*
* This is still used for things like creating data source tables, but in the future will be
- * cleaned up to integrate more nicely with [[HiveCatalog]].
+ * cleaned up to integrate more nicely with [[HiveExternalCatalog]].
*/
private[hive] class HiveMetastoreCatalog(val client: HiveClient, hive: HiveContext)
extends Logging {
@@ -124,8 +136,8 @@ private[hive] class HiveMetastoreCatalog(val client: HiveClient, hive: HiveConte
private def getQualifiedTableName(t: CatalogTable): QualifiedTableName = {
QualifiedTableName(
- t.name.database.getOrElse(getCurrentDatabase).toLowerCase,
- t.name.table.toLowerCase)
+ t.identifier.database.getOrElse(getCurrentDatabase).toLowerCase,
+ t.identifier.table.toLowerCase)
}
/** A cache of Spark SQL data source tables that have been accessed. */
@@ -299,7 +311,7 @@ private[hive] class HiveMetastoreCatalog(val client: HiveClient, hive: HiveConte
def newSparkSQLSpecificMetastoreTable(): CatalogTable = {
CatalogTable(
- name = TableIdentifier(tblName, Option(dbName)),
+ identifier = TableIdentifier(tblName, Option(dbName)),
tableType = tableType,
schema = Nil,
storage = CatalogStorageFormat(
@@ -319,7 +331,7 @@ private[hive] class HiveMetastoreCatalog(val client: HiveClient, hive: HiveConte
assert(relation.partitionSchema.isEmpty)
CatalogTable(
- name = TableIdentifier(tblName, Option(dbName)),
+ identifier = TableIdentifier(tblName, Option(dbName)),
tableType = tableType,
storage = CatalogStorageFormat(
locationUri = Some(relation.location.paths.map(_.toUri.toString).head),
@@ -431,7 +443,7 @@ private[hive] class HiveMetastoreCatalog(val client: HiveClient, hive: HiveConte
alias match {
// because hive use things like `_c0` to build the expanded text
// currently we cannot support view from "create view v1(c1) as ..."
- case None => SubqueryAlias(table.name.table, hive.parseSql(viewText))
+ case None => SubqueryAlias(table.identifier.table, hive.parseSql(viewText))
case Some(aliasText) => SubqueryAlias(aliasText, hive.parseSql(viewText))
}
} else {
@@ -440,58 +452,72 @@ private[hive] class HiveMetastoreCatalog(val client: HiveClient, hive: HiveConte
}
}
- private def convertToParquetRelation(metastoreRelation: MetastoreRelation): LogicalRelation = {
- val metastoreSchema = StructType.fromAttributes(metastoreRelation.output)
- val mergeSchema = hive.convertMetastoreParquetWithSchemaMerging
-
- val parquetOptions = Map(
- ParquetRelation.MERGE_SCHEMA -> mergeSchema.toString,
- ParquetRelation.METASTORE_TABLE_NAME -> TableIdentifier(
- metastoreRelation.tableName,
- Some(metastoreRelation.databaseName)
- ).unquotedString
- )
- val tableIdentifier =
- QualifiedTableName(metastoreRelation.databaseName, metastoreRelation.tableName)
-
- def getCached(
- tableIdentifier: QualifiedTableName,
- pathsInMetastore: Seq[String],
- schemaInMetastore: StructType,
- partitionSpecInMetastore: Option[PartitionSpec]): Option[LogicalRelation] = {
- cachedDataSourceTables.getIfPresent(tableIdentifier) match {
- case null => None // Cache miss
- case logical @ LogicalRelation(parquetRelation: HadoopFsRelation, _, _) =>
- // If we have the same paths, same schema, and same partition spec,
- // we will use the cached Parquet Relation.
- val useCached =
- parquetRelation.location.paths.map(_.toString).toSet == pathsInMetastore.toSet &&
- logical.schema.sameType(metastoreSchema) &&
- parquetRelation.partitionSpec == partitionSpecInMetastore.getOrElse {
- PartitionSpec(StructType(Nil), Array.empty[datasources.PartitionDirectory])
+ private def getCached(
+ tableIdentifier: QualifiedTableName,
+ metastoreRelation: MetastoreRelation,
+ schemaInMetastore: StructType,
+ expectedFileFormat: Class[_ <: FileFormat],
+ expectedBucketSpec: Option[BucketSpec],
+ partitionSpecInMetastore: Option[PartitionSpec]): Option[LogicalRelation] = {
+
+ cachedDataSourceTables.getIfPresent(tableIdentifier) match {
+ case null => None // Cache miss
+ case logical @ LogicalRelation(relation: HadoopFsRelation, _, _) =>
+ val pathsInMetastore = metastoreRelation.table.storage.locationUri.toSeq
+ val cachedRelationFileFormatClass = relation.fileFormat.getClass
+
+ expectedFileFormat match {
+ case `cachedRelationFileFormatClass` =>
+ // If we have the same paths, same schema, and same partition spec,
+ // we will use the cached relation.
+ val useCached =
+ relation.location.paths.map(_.toString).toSet == pathsInMetastore.toSet &&
+ logical.schema.sameType(schemaInMetastore) &&
+ relation.bucketSpec == expectedBucketSpec &&
+ relation.partitionSpec == partitionSpecInMetastore.getOrElse {
+ PartitionSpec(StructType(Nil), Array.empty[PartitionDirectory])
+ }
+
+ if (useCached) {
+ Some(logical)
+ } else {
+ // If the cached relation is not updated, we invalidate it right away.
+ cachedDataSourceTables.invalidate(tableIdentifier)
+ None
}
-
- if (useCached) {
- Some(logical)
- } else {
- // If the cached relation is not updated, we invalidate it right away.
+ case _ =>
+ logWarning(
+ s"${metastoreRelation.databaseName}.${metastoreRelation.tableName} " +
+ s"should be stored as $expectedFileFormat. However, we are getting " +
+ s"a ${relation.fileFormat} from the metastore cache. This cached " +
+ s"entry will be invalidated.")
cachedDataSourceTables.invalidate(tableIdentifier)
None
- }
- case other =>
- logWarning(
- s"${metastoreRelation.databaseName}.${metastoreRelation.tableName} should be stored " +
- s"as Parquet. However, we are getting a $other from the metastore cache. " +
- s"This cached entry will be invalidated.")
- cachedDataSourceTables.invalidate(tableIdentifier)
- None
- }
+ }
+ case other =>
+ logWarning(
+ s"${metastoreRelation.databaseName}.${metastoreRelation.tableName} should be stored " +
+ s"as $expectedFileFormat. However, we are getting a $other from the metastore cache. " +
+ s"This cached entry will be invalidated.")
+ cachedDataSourceTables.invalidate(tableIdentifier)
+ None
}
+ }
+
+ private def convertToLogicalRelation(metastoreRelation: MetastoreRelation,
+ options: Map[String, String],
+ defaultSource: FileFormat,
+ fileFormatClass: Class[_ <: FileFormat],
+ fileType: String): LogicalRelation = {
+ val metastoreSchema = StructType.fromAttributes(metastoreRelation.output)
+ val tableIdentifier =
+ QualifiedTableName(metastoreRelation.databaseName, metastoreRelation.tableName)
+ val bucketSpec = None // We don't support hive bucketed tables, only ones we write out.
val result = if (metastoreRelation.hiveQlTable.isPartitioned) {
val partitionSchema = StructType.fromAttributes(metastoreRelation.partitionKeys)
val partitionColumnDataTypes = partitionSchema.map(_.dataType)
- // We're converting the entire table into ParquetRelation, so predicates to Hive metastore
+ // We're converting the entire table into HadoopFsRelation, so predicates to Hive metastore
// are empty.
val partitions = metastoreRelation.getHiveQlPartitions().map { p =>
val location = p.getLocation
@@ -504,54 +530,65 @@ private[hive] class HiveMetastoreCatalog(val client: HiveClient, hive: HiveConte
val cached = getCached(
tableIdentifier,
- metastoreRelation.table.storage.locationUri.toSeq,
+ metastoreRelation,
metastoreSchema,
+ fileFormatClass,
+ bucketSpec,
Some(partitionSpec))
- val parquetRelation = cached.getOrElse {
+ val hadoopFsRelation = cached.getOrElse {
val paths = new Path(metastoreRelation.table.storage.locationUri.get) :: Nil
val fileCatalog = new MetaStoreFileCatalog(hive, paths, partitionSpec)
- val format = new DefaultSource()
- val inferredSchema = format.inferSchema(hive, parquetOptions, fileCatalog.allFiles())
- val mergedSchema = inferredSchema.map { inferred =>
- ParquetRelation.mergeMetastoreParquetSchema(metastoreSchema, inferred)
- }.getOrElse(metastoreSchema)
+ val inferredSchema = if (fileType.equals("parquet")) {
+ val inferredSchema = defaultSource.inferSchema(hive, options, fileCatalog.allFiles())
+ inferredSchema.map { inferred =>
+ ParquetRelation.mergeMetastoreParquetSchema(metastoreSchema, inferred)
+ }.getOrElse(metastoreSchema)
+ } else {
+ defaultSource.inferSchema(hive, options, fileCatalog.allFiles()).get
+ }
val relation = HadoopFsRelation(
sqlContext = hive,
location = fileCatalog,
partitionSchema = partitionSchema,
- dataSchema = mergedSchema,
- bucketSpec = None, // We don't support hive bucketed tables, only ones we write out.
- fileFormat = new DefaultSource(),
- options = parquetOptions)
+ dataSchema = inferredSchema,
+ bucketSpec = bucketSpec,
+ fileFormat = defaultSource,
+ options = options)
val created = LogicalRelation(relation)
cachedDataSourceTables.put(tableIdentifier, created)
created
}
- parquetRelation
+ hadoopFsRelation
} else {
val paths = Seq(metastoreRelation.hiveQlTable.getDataLocation.toString)
- val cached = getCached(tableIdentifier, paths, metastoreSchema, None)
- val parquetRelation = cached.getOrElse {
+ val cached = getCached(tableIdentifier,
+ metastoreRelation,
+ metastoreSchema,
+ fileFormatClass,
+ bucketSpec,
+ None)
+ val logicalRelation = cached.getOrElse {
val created =
LogicalRelation(
DataSource(
sqlContext = hive,
paths = paths,
userSpecifiedSchema = Some(metastoreRelation.schema),
- options = parquetOptions,
- className = "parquet").resolveRelation())
+ bucketSpec = bucketSpec,
+ options = options,
+ className = fileType).resolveRelation())
cachedDataSourceTables.put(tableIdentifier, created)
created
}
- parquetRelation
+ logicalRelation
}
result.copy(expectedOutputAttributes = Some(metastoreRelation.output))
}
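The refactor above folds the Parquet-only conversion into a generic convertToLogicalRelation, and getCached now checks the cached relation's file format class, paths, schema, bucket spec and partition spec before reusing it, invalidating stale entries on the spot. A minimal sketch of that validate-or-invalidate lookup, using a plain mutable Map instead of the Guava cache used here:

import scala.collection.mutable

case class CachedRelation(paths: Set[String], schemaDdl: String, fileFormat: String)

val cache = mutable.Map.empty[String, CachedRelation]

def getCached(
    key: String,
    paths: Set[String],
    schemaDdl: String,
    fileFormat: String): Option[CachedRelation] =
  cache.get(key) match {
    case Some(c) if c.paths == paths && c.schemaDdl == schemaDdl && c.fileFormat == fileFormat =>
      Some(c)             // everything still matches: reuse the cached relation
    case Some(_) =>
      cache.remove(key)   // stale entry: invalidate right away
      None
    case None => None
  }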
@@ -561,6 +598,27 @@ private[hive] class HiveMetastoreCatalog(val client: HiveClient, hive: HiveConte
* data source relations for better performance.
*/
object ParquetConversions extends Rule[LogicalPlan] {
+ private def shouldConvertMetastoreParquet(relation: MetastoreRelation): Boolean = {
+ relation.tableDesc.getSerdeClassName.toLowerCase.contains("parquet") &&
+ hive.convertMetastoreParquet
+ }
+
+ private def convertToParquetRelation(relation: MetastoreRelation): LogicalRelation = {
+ val defaultSource = new ParquetDefaultSource()
+ val fileFormatClass = classOf[ParquetDefaultSource]
+
+ val mergeSchema = hive.convertMetastoreParquetWithSchemaMerging
+ val options = Map(
+ ParquetRelation.MERGE_SCHEMA -> mergeSchema.toString,
+ ParquetRelation.METASTORE_TABLE_NAME -> TableIdentifier(
+ relation.tableName,
+ Some(relation.databaseName)
+ ).unquotedString
+ )
+
+ convertToLogicalRelation(relation, options, defaultSource, fileFormatClass, "parquet")
+ }
+
override def apply(plan: LogicalPlan): LogicalPlan = {
if (!plan.resolved || plan.analyzed) {
return plan
@@ -570,22 +628,17 @@ private[hive] class HiveMetastoreCatalog(val client: HiveClient, hive: HiveConte
// Write path
case InsertIntoTable(r: MetastoreRelation, partition, child, overwrite, ifNotExists)
// Inserting into partitioned table is not supported in Parquet data source (yet).
- if !r.hiveQlTable.isPartitioned && hive.convertMetastoreParquet &&
- r.tableDesc.getSerdeClassName.toLowerCase.contains("parquet") =>
- val parquetRelation = convertToParquetRelation(r)
- InsertIntoTable(parquetRelation, partition, child, overwrite, ifNotExists)
+ if !r.hiveQlTable.isPartitioned && shouldConvertMetastoreParquet(r) =>
+ InsertIntoTable(convertToParquetRelation(r), partition, child, overwrite, ifNotExists)
// Write path
case InsertIntoHiveTable(r: MetastoreRelation, partition, child, overwrite, ifNotExists)
// Inserting into partitioned table is not supported in Parquet data source (yet).
- if !r.hiveQlTable.isPartitioned && hive.convertMetastoreParquet &&
- r.tableDesc.getSerdeClassName.toLowerCase.contains("parquet") =>
- val parquetRelation = convertToParquetRelation(r)
- InsertIntoTable(parquetRelation, partition, child, overwrite, ifNotExists)
+ if !r.hiveQlTable.isPartitioned && shouldConvertMetastoreParquet(r) =>
+ InsertIntoTable(convertToParquetRelation(r), partition, child, overwrite, ifNotExists)
// Read path
- case relation: MetastoreRelation if hive.convertMetastoreParquet &&
- relation.tableDesc.getSerdeClassName.toLowerCase.contains("parquet") =>
+ case relation: MetastoreRelation if shouldConvertMetastoreParquet(relation) =>
val parquetRelation = convertToParquetRelation(relation)
SubqueryAlias(relation.alias.getOrElse(relation.tableName), parquetRelation)
}
@@ -593,6 +646,50 @@ private[hive] class HiveMetastoreCatalog(val client: HiveClient, hive: HiveConte
}
/**
+ * When scanning Metastore ORC tables, convert them to ORC data source relations
+ * for better performance.
+ */
+ object OrcConversions extends Rule[LogicalPlan] {
+ private def shouldConvertMetastoreOrc(relation: MetastoreRelation): Boolean = {
+ relation.tableDesc.getSerdeClassName.toLowerCase.contains("orc") &&
+ hive.convertMetastoreOrc
+ }
+
+ private def convertToOrcRelation(relation: MetastoreRelation): LogicalRelation = {
+ val defaultSource = new OrcDefaultSource()
+ val fileFormatClass = classOf[OrcDefaultSource]
+ val options = Map[String, String]()
+
+ convertToLogicalRelation(relation, options, defaultSource, fileFormatClass, "orc")
+ }
+
+ override def apply(plan: LogicalPlan): LogicalPlan = {
+ if (!plan.resolved || plan.analyzed) {
+ return plan
+ }
+
+ plan transformUp {
+ // Write path
+ case InsertIntoTable(r: MetastoreRelation, partition, child, overwrite, ifNotExists)
+ // Inserting into partitioned table is not supported in Orc data source (yet).
+ if !r.hiveQlTable.isPartitioned && shouldConvertMetastoreOrc(r) =>
+ InsertIntoTable(convertToOrcRelation(r), partition, child, overwrite, ifNotExists)
+
+ // Write path
+ case InsertIntoHiveTable(r: MetastoreRelation, partition, child, overwrite, ifNotExists)
+ // Inserting into partitioned table is not supported in Orc data source (yet).
+ if !r.hiveQlTable.isPartitioned && shouldConvertMetastoreOrc(r) =>
+ InsertIntoTable(convertToOrcRelation(r), partition, child, overwrite, ifNotExists)
+
+ // Read path
+ case relation: MetastoreRelation if shouldConvertMetastoreOrc(relation) =>
+ val orcRelation = convertToOrcRelation(relation)
+ SubqueryAlias(relation.alias.getOrElse(relation.tableName), orcRelation)
+ }
+ }
+ }
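OrcConversions above mirrors the existing ParquetConversions rule: when the table's SerDe name contains "orc" and the new flag is on, reads and non-partitioned inserts are rewritten to the ORC data source relation. A minimal sketch of the transform-and-substitute shape, on a toy plan type rather than Catalyst's LogicalPlan:

sealed trait Plan
case class MetastoreScan(serde: String) extends Plan
case class OrcScan(original: MetastoreScan) extends Plan

def convert(plan: Plan, convertMetastoreOrc: Boolean): Plan = plan match {
  case m @ MetastoreScan(serde) if serde.toLowerCase.contains("orc") && convertMetastoreOrc =>
    OrcScan(m)      // swap the Hive SerDe scan for the ORC data source scan
  case other => other
}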
+
+ /**
* Creates any tables required for query execution.
* For example, because of a CREATE TABLE X AS statement.
*/
@@ -611,7 +708,7 @@ private[hive] class HiveMetastoreCatalog(val client: HiveClient, hive: HiveConte
val QualifiedTableName(dbName, tblName) = getQualifiedTableName(table)
execution.CreateViewAsSelect(
- table.copy(name = TableIdentifier(tblName, Some(dbName))),
+ table.copy(identifier = TableIdentifier(tblName, Some(dbName))),
child,
allowExisting,
replace)
@@ -633,7 +730,7 @@ private[hive] class HiveMetastoreCatalog(val client: HiveClient, hive: HiveConte
if (hive.convertCTAS && table.storage.serde.isEmpty) {
// Do the conversion when spark.sql.hive.convertCTAS is true and the query
// does not specify any storage format (file format and storage handler).
- if (table.name.database.isDefined) {
+ if (table.identifier.database.isDefined) {
throw new AnalysisException(
"Cannot specify database name in a CTAS statement " +
"when spark.sql.hive.convertCTAS is set to true.")
@@ -641,7 +738,7 @@ private[hive] class HiveMetastoreCatalog(val client: HiveClient, hive: HiveConte
val mode = if (allowExisting) SaveMode.Ignore else SaveMode.ErrorIfExists
CreateTableUsingAsSelect(
- TableIdentifier(desc.name.table),
+ TableIdentifier(desc.identifier.table),
conf.defaultDataSourceName,
temporary = false,
Array.empty[String],
@@ -662,7 +759,7 @@ private[hive] class HiveMetastoreCatalog(val client: HiveClient, hive: HiveConte
val QualifiedTableName(dbName, tblName) = getQualifiedTableName(table)
execution.CreateTableAsSelect(
- desc.copy(name = TableIdentifier(tblName, Some(dbName))),
+ desc.copy(identifier = TableIdentifier(tblName, Some(dbName))),
child,
allowExisting)
}
@@ -792,7 +889,7 @@ private[hive] case class MetastoreRelation(
// We start by constructing an API table as Hive performs several important transformations
// internally when converting an API table to a QL table.
val tTable = new org.apache.hadoop.hive.metastore.api.Table()
- tTable.setTableName(table.name.table)
+ tTable.setTableName(table.identifier.table)
tTable.setDbName(table.database)
val tableParameters = new java.util.HashMap[String, String]()
@@ -808,8 +905,13 @@ private[hive] case class MetastoreRelation(
val sd = new org.apache.hadoop.hive.metastore.api.StorageDescriptor()
tTable.setSd(sd)
- sd.setCols(table.schema.map(toHiveColumn).asJava)
- tTable.setPartitionKeys(table.partitionColumns.map(toHiveColumn).asJava)
+
+ // Note: In Hive the schema and partition columns must be disjoint sets
+ val (partCols, schema) = table.schema.map(toHiveColumn).partition { c =>
+ table.partitionColumnNames.contains(c.getName)
+ }
+ sd.setCols(schema.asJava)
+ tTable.setPartitionKeys(partCols.asJava)
table.storage.locationUri.foreach(sd.setLocation)
table.storage.inputFormat.foreach(sd.setInputFormat)
@@ -916,7 +1018,10 @@ private[hive] case class MetastoreRelation(
val partitionKeys = table.partitionColumns.map(_.toAttribute)
/** Non-partitionKey attributes */
- val attributes = table.schema.map(_.toAttribute)
+ // TODO: just make this hold the schema itself, not just non-partition columns
+ val attributes = table.schema
+ .filter { c => !table.partitionColumnNames.contains(c.name) }
+ .map(_.toAttribute)
val output = attributes ++ partitionKeys
@@ -977,3 +1082,28 @@ private[hive] object HiveMetastoreTypes {
case udt: UserDefinedType[_] => toMetastoreType(udt.sqlType)
}
}
+
+private[hive] case class CreateTableAsSelect(
+ tableDesc: CatalogTable,
+ child: LogicalPlan,
+ allowExisting: Boolean) extends UnaryNode with Command {
+
+ override def output: Seq[Attribute] = Seq.empty[Attribute]
+ override lazy val resolved: Boolean =
+ tableDesc.identifier.database.isDefined &&
+ tableDesc.schema.nonEmpty &&
+ tableDesc.storage.serde.isDefined &&
+ tableDesc.storage.inputFormat.isDefined &&
+ tableDesc.storage.outputFormat.isDefined &&
+ childrenResolved
+}
+
+private[hive] case class CreateViewAsSelect(
+ tableDesc: CatalogTable,
+ child: LogicalPlan,
+ allowExisting: Boolean,
+ replace: Boolean,
+ sql: String) extends UnaryNode with Command {
+ override def output: Seq[Attribute] = Seq.empty[Attribute]
+ override lazy val resolved: Boolean = false
+}
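CreateTableAsSelect and CreateViewAsSelect are moved here from the deleted HiveQl.scala below, updated to the new identifier field; CreateTableAsSelect only counts as resolved once the table description is complete. A minimal sketch of that "resolved only when fully described" check, with plain Options standing in for CatalogTable:

case class TableSketch(
    database: Option[String],
    schema: Seq[String],
    serde: Option[String],
    inputFormat: Option[String],
    outputFormat: Option[String])

def resolved(table: TableSketch, childrenResolved: Boolean): Boolean =
  table.database.isDefined &&
    table.schema.nonEmpty &&
    table.serde.isDefined &&
    table.inputFormat.isDefined &&
    table.outputFormat.isDefined &&
    childrenResolved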
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
deleted file mode 100644
index 6586b90377..0000000000
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
+++ /dev/null
@@ -1,751 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.hive
-
-import java.util.Locale
-
-import scala.collection.JavaConverters._
-
-import org.apache.hadoop.hive.common.`type`.HiveDecimal
-import org.apache.hadoop.hive.conf.HiveConf
-import org.apache.hadoop.hive.conf.HiveConf.ConfVars
-import org.apache.hadoop.hive.ql.exec.{FunctionInfo, FunctionRegistry}
-import org.apache.hadoop.hive.ql.parse.EximUtil
-import org.apache.hadoop.hive.ql.session.SessionState
-import org.apache.hadoop.hive.serde.serdeConstants
-import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe
-
-import org.apache.spark.internal.Logging
-import org.apache.spark.sql.catalyst.TableIdentifier
-import org.apache.spark.sql.catalyst.catalog._
-import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.parser._
-import org.apache.spark.sql.catalyst.plans._
-import org.apache.spark.sql.catalyst.plans.logical._
-import org.apache.spark.sql.execution.SparkQl
-import org.apache.spark.sql.hive.HiveShim.HiveFunctionWrapper
-import org.apache.spark.sql.hive.execution._
-import org.apache.spark.sql.types._
-import org.apache.spark.sql.AnalysisException
-
-/**
- * Used when we need to start parsing the AST before deciding that we are going to pass the command
- * back for Hive to execute natively. Will be replaced with a native command that contains the
- * cmd string.
- */
-private[hive] case object NativePlaceholder extends LogicalPlan {
- override def children: Seq[LogicalPlan] = Seq.empty
- override def output: Seq[Attribute] = Seq.empty
-}
-
-private[hive] case class CreateTableAsSelect(
- tableDesc: CatalogTable,
- child: LogicalPlan,
- allowExisting: Boolean) extends UnaryNode with Command {
-
- override def output: Seq[Attribute] = Seq.empty[Attribute]
- override lazy val resolved: Boolean =
- tableDesc.name.database.isDefined &&
- tableDesc.schema.nonEmpty &&
- tableDesc.storage.serde.isDefined &&
- tableDesc.storage.inputFormat.isDefined &&
- tableDesc.storage.outputFormat.isDefined &&
- childrenResolved
-}
-
-private[hive] case class CreateViewAsSelect(
- tableDesc: CatalogTable,
- child: LogicalPlan,
- allowExisting: Boolean,
- replace: Boolean,
- sql: String) extends UnaryNode with Command {
- override def output: Seq[Attribute] = Seq.empty[Attribute]
- override lazy val resolved: Boolean = false
-}
-
-/** Provides a mapping from HiveQL statements to catalyst logical plans and expression trees. */
-private[hive] class HiveQl(conf: ParserConf) extends SparkQl(conf) with Logging {
- import ParseUtils._
- import ParserUtils._
-
- protected val nativeCommands = Seq(
- "TOK_ALTERDATABASE_OWNER",
- "TOK_ALTERDATABASE_PROPERTIES",
- "TOK_ALTERINDEX_PROPERTIES",
- "TOK_ALTERINDEX_REBUILD",
- "TOK_ALTERTABLE_ALTERPARTS",
- "TOK_ALTERTABLE_PARTITION",
- "TOK_ALTERVIEW_ADDPARTS",
- "TOK_ALTERVIEW_AS",
- "TOK_ALTERVIEW_DROPPARTS",
- "TOK_ALTERVIEW_PROPERTIES",
- "TOK_ALTERVIEW_RENAME",
-
- "TOK_CREATEINDEX",
- "TOK_CREATEMACRO",
- "TOK_CREATEROLE",
-
- "TOK_DESCDATABASE",
-
- "TOK_DROPFUNCTION",
- "TOK_DROPINDEX",
- "TOK_DROPMACRO",
- "TOK_DROPROLE",
- "TOK_DROPTABLE_PROPERTIES",
- "TOK_DROPVIEW",
- "TOK_DROPVIEW_PROPERTIES",
-
- "TOK_EXPORT",
-
- "TOK_GRANT",
- "TOK_GRANT_ROLE",
-
- "TOK_IMPORT",
-
- "TOK_LOAD",
-
- "TOK_LOCKTABLE",
-
- "TOK_MSCK",
-
- "TOK_REVOKE",
-
- "TOK_SHOW_COMPACTIONS",
- "TOK_SHOW_CREATETABLE",
- "TOK_SHOW_GRANT",
- "TOK_SHOW_ROLE_GRANT",
- "TOK_SHOW_ROLE_PRINCIPALS",
- "TOK_SHOW_ROLES",
- "TOK_SHOW_SET_ROLE",
- "TOK_SHOW_TABLESTATUS",
- "TOK_SHOW_TBLPROPERTIES",
- "TOK_SHOW_TRANSACTIONS",
- "TOK_SHOWCOLUMNS",
- "TOK_SHOWDATABASES",
- "TOK_SHOWINDEXES",
- "TOK_SHOWLOCKS",
- "TOK_SHOWPARTITIONS",
-
- "TOK_UNLOCKTABLE"
- )
-
- // Commands that we do not need to explain.
- protected val noExplainCommands = Seq(
- "TOK_DESCTABLE",
- "TOK_SHOWTABLES",
- "TOK_TRUNCATETABLE", // truncate table" is a NativeCommand, does not need to explain.
- "TOK_ALTERTABLE"
- ) ++ nativeCommands
-
- /**
- * Returns the HiveConf
- */
- private[this] def hiveConf: HiveConf = {
- var ss = SessionState.get()
- // SessionState is lazy initialization, it can be null here
- if (ss == null) {
- val original = Thread.currentThread().getContextClassLoader
- val conf = new HiveConf(classOf[SessionState])
- conf.setClassLoader(original)
- ss = new SessionState(conf)
- SessionState.start(ss)
- }
- ss.getConf
- }
-
- protected def getProperties(node: ASTNode): Seq[(String, String)] = node match {
- case Token("TOK_TABLEPROPLIST", list) =>
- list.map {
- case Token("TOK_TABLEPROPERTY", Token(key, Nil) :: Token(value, Nil) :: Nil) =>
- unquoteString(key) -> unquoteString(value)
- }
- }
-
- private def createView(
- view: ASTNode,
- viewNameParts: ASTNode,
- query: ASTNode,
- schema: Seq[CatalogColumn],
- properties: Map[String, String],
- allowExist: Boolean,
- replace: Boolean): CreateViewAsSelect = {
- val tableIdentifier = extractTableIdent(viewNameParts)
- val originalText = query.source
- val tableDesc = CatalogTable(
- name = tableIdentifier,
- tableType = CatalogTableType.VIRTUAL_VIEW,
- schema = schema,
- storage = CatalogStorageFormat(
- locationUri = None,
- inputFormat = None,
- outputFormat = None,
- serde = None,
- serdeProperties = Map.empty[String, String]
- ),
- properties = properties,
- viewOriginalText = Some(originalText),
- viewText = Some(originalText))
-
- // We need to keep the original SQL string so that if `spark.sql.nativeView` is
- // false, we can fall back to use hive native command later.
- // We can remove this when parser is configurable(can access SQLConf) in the future.
- val sql = view.source
- CreateViewAsSelect(tableDesc, nodeToPlan(query), allowExist, replace, sql)
- }
-
- /** Creates LogicalPlan for a given SQL string. */
- override def parsePlan(sql: String): LogicalPlan = {
- safeParse(sql, ParseDriver.parsePlan(sql, conf)) { ast =>
- if (nativeCommands.contains(ast.text)) {
- HiveNativeCommand(sql)
- } else {
- nodeToPlan(ast) match {
- case NativePlaceholder => HiveNativeCommand(sql)
- case plan => plan
- }
- }
- }
- }
-
- protected override def isNoExplainCommand(command: String): Boolean =
- noExplainCommands.contains(command)
-
- protected override def nodeToPlan(node: ASTNode): LogicalPlan = {
- node match {
- case Token("TOK_DFS", Nil) =>
- HiveNativeCommand(node.source + " " + node.remainder)
-
- case Token("TOK_ADDFILE", Nil) =>
- AddFile(node.remainder)
-
- case Token("TOK_ADDJAR", Nil) =>
- AddJar(node.remainder)
-
- // Special drop table that also uncaches.
- case Token("TOK_DROPTABLE", Token("TOK_TABNAME", tableNameParts) :: ifExists) =>
- val tableName = tableNameParts.map { case Token(p, Nil) => p }.mkString(".")
- DropTable(tableName, ifExists.nonEmpty)
-
- // Support "ANALYZE TABLE tableName COMPUTE STATISTICS noscan"
- case Token("TOK_ANALYZE",
- Token("TOK_TAB", Token("TOK_TABNAME", tableNameParts) :: partitionSpec) :: isNoscan) =>
- // Reference:
- // https://cwiki.apache.org/confluence/display/Hive/StatsDev#StatsDev-ExistingTables
- if (partitionSpec.nonEmpty) {
- // Analyze partitions will be treated as a Hive native command.
- NativePlaceholder
- } else if (isNoscan.isEmpty) {
- // If users do not specify "noscan", it will be treated as a Hive native command.
- NativePlaceholder
- } else {
- val tableName = tableNameParts.map { case Token(p, Nil) => p }.mkString(".")
- AnalyzeTable(tableName)
- }
-
- case view @ Token("TOK_ALTERVIEW", children) =>
- val Some(nameParts) :: maybeQuery :: _ =
- getClauses(Seq(
- "TOK_TABNAME",
- "TOK_QUERY",
- "TOK_ALTERVIEW_ADDPARTS",
- "TOK_ALTERVIEW_DROPPARTS",
- "TOK_ALTERVIEW_PROPERTIES",
- "TOK_ALTERVIEW_RENAME"), children)
-
- // if ALTER VIEW doesn't have query part, let hive to handle it.
- maybeQuery.map { query =>
- createView(view, nameParts, query, Nil, Map(), allowExist = false, replace = true)
- }.getOrElse(NativePlaceholder)
-
- case view @ Token("TOK_CREATEVIEW", children)
- if children.collect { case t @ Token("TOK_QUERY", _) => t }.nonEmpty =>
- val Seq(
- Some(viewNameParts),
- Some(query),
- maybeComment,
- replace,
- allowExisting,
- maybeProperties,
- maybeColumns,
- maybePartCols
- ) = getClauses(Seq(
- "TOK_TABNAME",
- "TOK_QUERY",
- "TOK_TABLECOMMENT",
- "TOK_ORREPLACE",
- "TOK_IFNOTEXISTS",
- "TOK_TABLEPROPERTIES",
- "TOK_TABCOLNAME",
- "TOK_VIEWPARTCOLS"), children)
-
- // If the view is partitioned, we let hive handle it.
- if (maybePartCols.isDefined) {
- NativePlaceholder
- } else {
- val schema = maybeColumns.map { cols =>
- // We can't specify column types when create view, so fill it with null first, and
- // update it after the schema has been resolved later.
- nodeToColumns(cols, lowerCase = true).map(_.copy(dataType = null))
- }.getOrElse(Seq.empty[CatalogColumn])
-
- val properties = scala.collection.mutable.Map.empty[String, String]
-
- maybeProperties.foreach {
- case Token("TOK_TABLEPROPERTIES", list :: Nil) =>
- properties ++= getProperties(list)
- }
-
- maybeComment.foreach {
- case Token("TOK_TABLECOMMENT", child :: Nil) =>
- val comment = unescapeSQLString(child.text)
- if (comment ne null) {
- properties += ("comment" -> comment)
- }
- }
-
- createView(view, viewNameParts, query, schema, properties.toMap,
- allowExisting.isDefined, replace.isDefined)
- }
-
- case Token("TOK_CREATETABLE", children)
- if children.collect { case t @ Token("TOK_QUERY", _) => t }.nonEmpty =>
- // Reference: https://cwiki.apache.org/confluence/display/Hive/LanguageManual+DDL
- val (
- Some(tableNameParts) ::
- _ /* likeTable */ ::
- externalTable ::
- Some(query) ::
- allowExisting +:
- _) =
- getClauses(
- Seq(
- "TOK_TABNAME",
- "TOK_LIKETABLE",
- "EXTERNAL",
- "TOK_QUERY",
- "TOK_IFNOTEXISTS",
- "TOK_TABLECOMMENT",
- "TOK_TABCOLLIST",
- "TOK_TABLEPARTCOLS", // Partitioned by
- "TOK_TABLEBUCKETS", // Clustered by
- "TOK_TABLESKEWED", // Skewed by
- "TOK_TABLEROWFORMAT",
- "TOK_TABLESERIALIZER",
- "TOK_FILEFORMAT_GENERIC",
- "TOK_TABLEFILEFORMAT", // User-provided InputFormat and OutputFormat
- "TOK_STORAGEHANDLER", // Storage handler
- "TOK_TABLELOCATION",
- "TOK_TABLEPROPERTIES"),
- children)
- val tableIdentifier = extractTableIdent(tableNameParts)
-
- // TODO add bucket support
- var tableDesc: CatalogTable = CatalogTable(
- name = tableIdentifier,
- tableType =
- if (externalTable.isDefined) {
- CatalogTableType.EXTERNAL_TABLE
- } else {
- CatalogTableType.MANAGED_TABLE
- },
- storage = CatalogStorageFormat(
- locationUri = None,
- inputFormat = None,
- outputFormat = None,
- serde = None,
- serdeProperties = Map.empty[String, String]
- ),
- schema = Seq.empty[CatalogColumn])
-
- // default storage type abbreviation (e.g. RCFile, ORC, PARQUET etc.)
- val defaultStorageType = hiveConf.getVar(HiveConf.ConfVars.HIVEDEFAULTFILEFORMAT)
- // handle the default format for the storage type abbreviation
- val hiveSerDe = HiveSerDe.sourceToSerDe(defaultStorageType, hiveConf).getOrElse {
- HiveSerDe(
- inputFormat = Option("org.apache.hadoop.mapred.TextInputFormat"),
- outputFormat = Option("org.apache.hadoop.hive.ql.io.IgnoreKeyTextOutputFormat"))
- }
-
- tableDesc = tableDesc.withNewStorage(
- inputFormat = hiveSerDe.inputFormat.orElse(tableDesc.storage.inputFormat),
- outputFormat = hiveSerDe.outputFormat.orElse(tableDesc.storage.outputFormat),
- serde = hiveSerDe.serde.orElse(tableDesc.storage.serde))
-
- children.collect {
- case list @ Token("TOK_TABCOLLIST", _) =>
- val cols = nodeToColumns(list, lowerCase = true)
- if (cols != null) {
- tableDesc = tableDesc.copy(schema = cols)
- }
- case Token("TOK_TABLECOMMENT", child :: Nil) =>
- val comment = unescapeSQLString(child.text)
- // TODO support the sql text
- tableDesc = tableDesc.copy(viewText = Option(comment))
- case Token("TOK_TABLEPARTCOLS", list @ Token("TOK_TABCOLLIST", _) :: Nil) =>
- val cols = nodeToColumns(list.head, lowerCase = false)
- if (cols != null) {
- tableDesc = tableDesc.copy(partitionColumns = cols)
- }
- case Token("TOK_TABLEROWFORMAT", Token("TOK_SERDEPROPS", child :: Nil) :: Nil) =>
- val serdeParams = new java.util.HashMap[String, String]()
- child match {
- case Token("TOK_TABLEROWFORMATFIELD", rowChild1 :: rowChild2) =>
- val fieldDelim = unescapeSQLString (rowChild1.text)
- serdeParams.put(serdeConstants.FIELD_DELIM, fieldDelim)
- serdeParams.put(serdeConstants.SERIALIZATION_FORMAT, fieldDelim)
- if (rowChild2.length > 1) {
- val fieldEscape = unescapeSQLString (rowChild2.head.text)
- serdeParams.put(serdeConstants.ESCAPE_CHAR, fieldEscape)
- }
- case Token("TOK_TABLEROWFORMATCOLLITEMS", rowChild :: Nil) =>
- val collItemDelim = unescapeSQLString(rowChild.text)
- serdeParams.put(serdeConstants.COLLECTION_DELIM, collItemDelim)
- case Token("TOK_TABLEROWFORMATMAPKEYS", rowChild :: Nil) =>
- val mapKeyDelim = unescapeSQLString(rowChild.text)
- serdeParams.put(serdeConstants.MAPKEY_DELIM, mapKeyDelim)
- case Token("TOK_TABLEROWFORMATLINES", rowChild :: Nil) =>
- val lineDelim = unescapeSQLString(rowChild.text)
- if (!(lineDelim == "\n") && !(lineDelim == "10")) {
- throw new AnalysisException(
- s"LINES TERMINATED BY only supports newline '\\n' right now: $rowChild")
- }
- serdeParams.put(serdeConstants.LINE_DELIM, lineDelim)
- case Token("TOK_TABLEROWFORMATNULL", rowChild :: Nil) =>
- val nullFormat = unescapeSQLString(rowChild.text)
- // TODO support the nullFormat
- case _ => assert(false)
- }
- tableDesc = tableDesc.withNewStorage(
- serdeProperties = tableDesc.storage.serdeProperties ++ serdeParams.asScala)
- case Token("TOK_TABLELOCATION", child :: Nil) =>
- val location = EximUtil.relativeToAbsolutePath(hiveConf, unescapeSQLString(child.text))
- tableDesc = tableDesc.withNewStorage(locationUri = Option(location))
- case Token("TOK_TABLESERIALIZER", child :: Nil) =>
- tableDesc = tableDesc.withNewStorage(
- serde = Option(unescapeSQLString(child.children.head.text)))
- if (child.numChildren == 2) {
- // This is based on the readProps(..) method in
- // ql/src/java/org/apache/hadoop/hive/ql/parse/BaseSemanticAnalyzer.java:
- val serdeParams = child.children(1).children.head.children.map {
- case Token(_, Token(prop, Nil) :: valueNode) =>
- val value = valueNode.headOption
- .map(_.text)
- .map(unescapeSQLString)
- .orNull
- (unescapeSQLString(prop), value)
- }.toMap
- tableDesc = tableDesc.withNewStorage(
- serdeProperties = tableDesc.storage.serdeProperties ++ serdeParams)
- }
- case Token("TOK_FILEFORMAT_GENERIC", child :: Nil) =>
- child.text.toLowerCase(Locale.ENGLISH) match {
- case "orc" =>
- tableDesc = tableDesc.withNewStorage(
- inputFormat = Option("org.apache.hadoop.hive.ql.io.orc.OrcInputFormat"),
- outputFormat = Option("org.apache.hadoop.hive.ql.io.orc.OrcOutputFormat"))
- if (tableDesc.storage.serde.isEmpty) {
- tableDesc = tableDesc.withNewStorage(
- serde = Option("org.apache.hadoop.hive.ql.io.orc.OrcSerde"))
- }
-
- case "parquet" =>
- tableDesc = tableDesc.withNewStorage(
- inputFormat =
- Option("org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat"),
- outputFormat =
- Option("org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat"))
- if (tableDesc.storage.serde.isEmpty) {
- tableDesc = tableDesc.withNewStorage(
- serde = Option("org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe"))
- }
-
- case "rcfile" =>
- tableDesc = tableDesc.withNewStorage(
- inputFormat = Option("org.apache.hadoop.hive.ql.io.RCFileInputFormat"),
- outputFormat = Option("org.apache.hadoop.hive.ql.io.RCFileOutputFormat"))
- if (tableDesc.storage.serde.isEmpty) {
- tableDesc = tableDesc.withNewStorage(
- serde =
- Option("org.apache.hadoop.hive.serde2.columnar.LazyBinaryColumnarSerDe"))
- }
-
- case "textfile" =>
- tableDesc = tableDesc.withNewStorage(
- inputFormat = Option("org.apache.hadoop.mapred.TextInputFormat"),
- outputFormat = Option("org.apache.hadoop.hive.ql.io.IgnoreKeyTextOutputFormat"))
-
- case "sequencefile" =>
- tableDesc = tableDesc.withNewStorage(
- inputFormat = Option("org.apache.hadoop.mapred.SequenceFileInputFormat"),
- outputFormat = Option("org.apache.hadoop.mapred.SequenceFileOutputFormat"))
-
- case "avro" =>
- tableDesc = tableDesc.withNewStorage(
- inputFormat =
- Option("org.apache.hadoop.hive.ql.io.avro.AvroContainerInputFormat"),
- outputFormat =
- Option("org.apache.hadoop.hive.ql.io.avro.AvroContainerOutputFormat"))
- if (tableDesc.storage.serde.isEmpty) {
- tableDesc = tableDesc.withNewStorage(
- serde = Option("org.apache.hadoop.hive.serde2.avro.AvroSerDe"))
- }
-
- case _ =>
- throw new AnalysisException(
- s"Unrecognized file format in STORED AS clause: ${child.text}")
- }
-
- case Token("TOK_TABLESERIALIZER",
- Token("TOK_SERDENAME", Token(serdeName, Nil) :: otherProps) :: Nil) =>
- tableDesc = tableDesc.withNewStorage(serde = Option(unquoteString(serdeName)))
-
- otherProps match {
- case Token("TOK_TABLEPROPERTIES", list :: Nil) :: Nil =>
- tableDesc = tableDesc.withNewStorage(
- serdeProperties = tableDesc.storage.serdeProperties ++ getProperties(list))
- case _ =>
- }
-
- case Token("TOK_TABLEPROPERTIES", list :: Nil) =>
- tableDesc = tableDesc.copy(properties = tableDesc.properties ++ getProperties(list))
- case list @ Token("TOK_TABLEFILEFORMAT", _) =>
- tableDesc = tableDesc.withNewStorage(
- inputFormat = Option(unescapeSQLString(list.children.head.text)),
- outputFormat = Option(unescapeSQLString(list.children(1).text)))
- case Token("TOK_STORAGEHANDLER", _) =>
- throw new AnalysisException(
- "CREATE TABLE AS SELECT cannot be used for a non-native table")
- case _ => // Unsupported features
- }
-
- CreateTableAsSelect(tableDesc, nodeToPlan(query), allowExisting.isDefined)
-
- // If its not a "CTAS" like above then take it as a native command
- case Token("TOK_CREATETABLE", _) =>
- NativePlaceholder
-
- // Support "TRUNCATE TABLE table_name [PARTITION partition_spec]"
- case Token("TOK_TRUNCATETABLE", Token("TOK_TABLE_PARTITION", table) :: Nil) =>
- NativePlaceholder
-
- case _ =>
- super.nodeToPlan(node)
- }
- }
-
- protected override def nodeToDescribeFallback(node: ASTNode): LogicalPlan = NativePlaceholder
-
- protected override def nodeToTransformation(
- node: ASTNode,
- child: LogicalPlan): Option[logical.ScriptTransformation] = node match {
- case Token("TOK_SELEXPR",
- Token("TOK_TRANSFORM",
- Token("TOK_EXPLIST", inputExprs) ::
- Token("TOK_SERDE", inputSerdeClause) ::
- Token("TOK_RECORDWRITER", writerClause) ::
- // TODO: Need to support other types of (in/out)put
- Token(script, Nil) ::
- Token("TOK_SERDE", outputSerdeClause) ::
- Token("TOK_RECORDREADER", readerClause) ::
- outputClause) :: Nil) =>
-
- val (output, schemaLess) = outputClause match {
- case Token("TOK_ALIASLIST", aliases) :: Nil =>
- (aliases.map { case Token(name, Nil) =>
- AttributeReference(cleanIdentifier(name), StringType)() }, false)
- case Token("TOK_TABCOLLIST", attributes) :: Nil =>
- (attributes.map { case Token("TOK_TABCOL", Token(name, Nil) :: dataType :: Nil) =>
- AttributeReference(cleanIdentifier(name), nodeToDataType(dataType))() }, false)
- case Nil =>
- (List(AttributeReference("key", StringType)(),
- AttributeReference("value", StringType)()), true)
- case _ =>
- noParseRule("Transform", node)
- }
-
- type SerDeInfo = (
- Seq[(String, String)], // Input row format information
- Option[String], // Optional input SerDe class
- Seq[(String, String)], // Input SerDe properties
- Boolean // Whether to use default record reader/writer
- )
-
- def matchSerDe(clause: Seq[ASTNode]): SerDeInfo = clause match {
- case Token("TOK_SERDEPROPS", propsClause) :: Nil =>
- val rowFormat = propsClause.map {
- case Token(name, Token(value, Nil) :: Nil) => (name, value)
- }
- (rowFormat, None, Nil, false)
-
- case Token("TOK_SERDENAME", Token(serdeClass, Nil) :: Nil) :: Nil =>
- (Nil, Some(unescapeSQLString(serdeClass)), Nil, false)
-
- case Token("TOK_SERDENAME", Token(serdeClass, Nil) ::
- Token("TOK_TABLEPROPERTIES",
- Token("TOK_TABLEPROPLIST", propsClause) :: Nil) :: Nil) :: Nil =>
- val serdeProps = propsClause.map {
- case Token("TOK_TABLEPROPERTY", Token(name, Nil) :: Token(value, Nil) :: Nil) =>
- (unescapeSQLString(name), unescapeSQLString(value))
- }
-
- // SPARK-10310: Special cases LazySimpleSerDe
- // TODO Fully supports user-defined record reader/writer classes
- val unescapedSerDeClass = unescapeSQLString(serdeClass)
- val useDefaultRecordReaderWriter =
- unescapedSerDeClass == classOf[LazySimpleSerDe].getCanonicalName
- (Nil, Some(unescapedSerDeClass), serdeProps, useDefaultRecordReaderWriter)
-
- case Nil =>
- // Uses default TextRecordReader/TextRecordWriter, sets field delimiter here
- val serdeProps = Seq(serdeConstants.FIELD_DELIM -> "\t")
- (Nil, Option(hiveConf.getVar(ConfVars.HIVESCRIPTSERDE)), serdeProps, true)
- }
-
- val (inRowFormat, inSerdeClass, inSerdeProps, useDefaultRecordReader) =
- matchSerDe(inputSerdeClause)
-
- val (outRowFormat, outSerdeClass, outSerdeProps, useDefaultRecordWriter) =
- matchSerDe(outputSerdeClause)
-
- val unescapedScript = unescapeSQLString(script)
-
- // TODO Adds support for user-defined record reader/writer classes
- val recordReaderClass = if (useDefaultRecordReader) {
- Option(hiveConf.getVar(ConfVars.HIVESCRIPTRECORDREADER))
- } else {
- None
- }
-
- val recordWriterClass = if (useDefaultRecordWriter) {
- Option(hiveConf.getVar(ConfVars.HIVESCRIPTRECORDWRITER))
- } else {
- None
- }
-
- val schema = HiveScriptIOSchema(
- inRowFormat, outRowFormat,
- inSerdeClass, outSerdeClass,
- inSerdeProps, outSerdeProps,
- recordReaderClass, recordWriterClass,
- schemaLess)
-
- Some(
- logical.ScriptTransformation(
- inputExprs.map(nodeToExpr),
- unescapedScript,
- output,
- child, schema))
- case _ => None
- }
-
- protected override def nodeToGenerator(node: ASTNode): Generator = node match {
- case Token("TOK_FUNCTION", Token(functionName, Nil) :: children) =>
- val functionInfo: FunctionInfo =
- Option(FunctionRegistry.getFunctionInfo(functionName.toLowerCase)).getOrElse(
- sys.error(s"Couldn't find function $functionName"))
- val functionClassName = functionInfo.getFunctionClass.getName
- HiveGenericUDTF(
- functionName, new HiveFunctionWrapper(functionClassName), children.map(nodeToExpr))
- case other => super.nodeToGenerator(node)
- }
-
- // This is based the getColumns methods in
- // ql/src/java/org/apache/hadoop/hive/ql/parse/BaseSemanticAnalyzer.java
- protected def nodeToColumns(node: ASTNode, lowerCase: Boolean): Seq[CatalogColumn] = {
- node.children.map(_.children).collect {
- case Token(rawColName, Nil) :: colTypeNode :: comment =>
- val colName = if (!lowerCase) rawColName else rawColName.toLowerCase
- CatalogColumn(
- name = cleanIdentifier(colName),
- dataType = nodeToTypeString(colTypeNode),
- nullable = true,
- comment.headOption.map(n => unescapeSQLString(n.text)))
- }
- }
-
- // This is based on the following methods in
- // ql/src/java/org/apache/hadoop/hive/ql/parse/BaseSemanticAnalyzer.java:
- // getTypeStringFromAST
- // getStructTypeStringFromAST
- // getUnionTypeStringFromAST
- protected def nodeToTypeString(node: ASTNode): String = node.tokenType match {
- case SparkSqlParser.TOK_LIST =>
- val listType :: Nil = node.children
- val listTypeString = nodeToTypeString(listType)
- s"${serdeConstants.LIST_TYPE_NAME}<$listTypeString>"
-
- case SparkSqlParser.TOK_MAP =>
- val keyType :: valueType :: Nil = node.children
- val keyTypeString = nodeToTypeString(keyType)
- val valueTypeString = nodeToTypeString(valueType)
- s"${serdeConstants.MAP_TYPE_NAME}<$keyTypeString,$valueTypeString>"
-
- case SparkSqlParser.TOK_STRUCT =>
- val typeNode = node.children.head
- require(typeNode.children.nonEmpty, "Struct must have one or more columns.")
- val structColStrings = typeNode.children.map { columnNode =>
- val Token(colName, Nil) :: colTypeNode :: Nil = columnNode.children
- cleanIdentifier(colName) + ":" + nodeToTypeString(colTypeNode)
- }
- s"${serdeConstants.STRUCT_TYPE_NAME}<${structColStrings.mkString(",")}>"
-
- case SparkSqlParser.TOK_UNIONTYPE =>
- val typeNode = node.children.head
- val unionTypesString = typeNode.children.map(nodeToTypeString).mkString(",")
- s"${serdeConstants.UNION_TYPE_NAME}<$unionTypesString>"
-
- case SparkSqlParser.TOK_CHAR =>
- val Token(size, Nil) :: Nil = node.children
- s"${serdeConstants.CHAR_TYPE_NAME}($size)"
-
- case SparkSqlParser.TOK_VARCHAR =>
- val Token(size, Nil) :: Nil = node.children
- s"${serdeConstants.VARCHAR_TYPE_NAME}($size)"
-
- case SparkSqlParser.TOK_DECIMAL =>
- val precisionAndScale = node.children match {
- case Token(precision, Nil) :: Token(scale, Nil) :: Nil =>
- precision + "," + scale
- case Token(precision, Nil) :: Nil =>
- precision + "," + HiveDecimal.USER_DEFAULT_SCALE
- case Nil =>
- HiveDecimal.USER_DEFAULT_PRECISION + "," + HiveDecimal.USER_DEFAULT_SCALE
- case _ =>
- noParseRule("Decimal", node)
- }
- s"${serdeConstants.DECIMAL_TYPE_NAME}($precisionAndScale)"
-
- // Simple data types.
- case SparkSqlParser.TOK_BOOLEAN => serdeConstants.BOOLEAN_TYPE_NAME
- case SparkSqlParser.TOK_TINYINT => serdeConstants.TINYINT_TYPE_NAME
- case SparkSqlParser.TOK_SMALLINT => serdeConstants.SMALLINT_TYPE_NAME
- case SparkSqlParser.TOK_INT => serdeConstants.INT_TYPE_NAME
- case SparkSqlParser.TOK_BIGINT => serdeConstants.BIGINT_TYPE_NAME
- case SparkSqlParser.TOK_FLOAT => serdeConstants.FLOAT_TYPE_NAME
- case SparkSqlParser.TOK_DOUBLE => serdeConstants.DOUBLE_TYPE_NAME
- case SparkSqlParser.TOK_STRING => serdeConstants.STRING_TYPE_NAME
- case SparkSqlParser.TOK_BINARY => serdeConstants.BINARY_TYPE_NAME
- case SparkSqlParser.TOK_DATE => serdeConstants.DATE_TYPE_NAME
- case SparkSqlParser.TOK_TIMESTAMP => serdeConstants.TIMESTAMP_TYPE_NAME
- case SparkSqlParser.TOK_INTERVAL_YEAR_MONTH => serdeConstants.INTERVAL_YEAR_MONTH_TYPE_NAME
- case SparkSqlParser.TOK_INTERVAL_DAY_TIME => serdeConstants.INTERVAL_DAY_TIME_TYPE_NAME
- case SparkSqlParser.TOK_DATETIME => serdeConstants.DATETIME_TYPE_NAME
- case _ => null
- }
-
-}
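
The deleted HiveQl.scala above built Hive metastore type strings recursively from the parsed AST: array<...>, map<k,v>, struct<name:type,...>, decimal(p,s) and so on. The sketch below reproduces only that string construction over a simplified stand-in type ADT; it is independent of the real Catalyst and Hive type hierarchies and of the ANTLR-based parser that replaces this file.

// Stand-in ADT for illustration; not the Catalyst DataType or Hive type system.
object TypeStringSketch {
  sealed trait SimpleType
  case object IntT extends SimpleType
  case object StringT extends SimpleType
  case class ArrayT(elem: SimpleType) extends SimpleType
  case class MapT(key: SimpleType, value: SimpleType) extends SimpleType
  case class StructT(fields: Seq[(String, SimpleType)]) extends SimpleType
  case class DecimalT(precision: Int, scale: Int) extends SimpleType

  def toMetastoreType(t: SimpleType): String = t match {
    case IntT => "int"
    case StringT => "string"
    case ArrayT(elem) => s"array<${toMetastoreType(elem)}>"
    case MapT(key, value) => s"map<${toMetastoreType(key)},${toMetastoreType(value)}>"
    case StructT(fields) =>
      fields.map { case (name, ft) => s"$name:${toMetastoreType(ft)}" }
        .mkString("struct<", ",", ">")
    case DecimalT(precision, scale) => s"decimal($precision,$scale)"
  }

  def main(args: Array[String]): Unit = {
    println(toMetastoreType(MapT(StringT, ArrayT(DecimalT(10, 2)))))
    // map<string,array<decimal(10,2)>>
  }
}
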
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala
index aa44cba4b5..0cccc22e5a 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala
@@ -17,22 +17,39 @@
package org.apache.spark.sql.hive
+import scala.util.{Failure, Success, Try}
+import scala.util.control.NonFatal
+
+import org.apache.hadoop.fs.Path
+import org.apache.hadoop.hive.conf.HiveConf
+import org.apache.hadoop.hive.ql.exec.{UDAF, UDF}
+import org.apache.hadoop.hive.ql.exec.{FunctionRegistry => HiveFunctionRegistry}
+import org.apache.hadoop.hive.ql.udf.generic.{AbstractGenericUDAFResolver, GenericUDF, GenericUDTF}
+
import org.apache.spark.sql.catalyst.TableIdentifier
-import org.apache.spark.sql.catalyst.catalog.SessionCatalog
+import org.apache.spark.sql.catalyst.analysis.FunctionRegistry
+import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder
+import org.apache.spark.sql.catalyst.catalog.{FunctionResourceLoader, SessionCatalog}
+import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionInfo}
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, SubqueryAlias}
import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.execution.datasources.BucketSpec
+import org.apache.spark.sql.hive.HiveShim.HiveFunctionWrapper
import org.apache.spark.sql.hive.client.HiveClient
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.StructType
+import org.apache.spark.util.Utils
-class HiveSessionCatalog(
- externalCatalog: HiveCatalog,
+private[sql] class HiveSessionCatalog(
+ externalCatalog: HiveExternalCatalog,
client: HiveClient,
context: HiveContext,
+ functionResourceLoader: FunctionResourceLoader,
+ functionRegistry: FunctionRegistry,
conf: SQLConf)
- extends SessionCatalog(externalCatalog, conf) {
+ extends SessionCatalog(externalCatalog, functionResourceLoader, functionRegistry, conf) {
override def setCurrentDatabase(db: String): Unit = {
super.setCurrentDatabase(db)
@@ -41,11 +58,11 @@ class HiveSessionCatalog(
override def lookupRelation(name: TableIdentifier, alias: Option[String]): LogicalPlan = {
val table = formatTableName(name.table)
- if (name.database.isDefined || !tempTables.containsKey(table)) {
+ if (name.database.isDefined || !tempTables.contains(table)) {
val newName = name.copy(table = table)
metastoreCatalog.lookupRelation(newName, alias)
} else {
- val relation = tempTables.get(table)
+ val relation = tempTables(table)
val tableWithQualifiers = SubqueryAlias(table, relation)
// If an alias was specified by the lookup, wrap the plan in a subquery so that
// attributes are properly qualified with this alias.
@@ -57,6 +74,11 @@ class HiveSessionCatalog(
// | Methods and fields for interacting with HiveMetastoreCatalog |
// ----------------------------------------------------------------
+ override def getDefaultDBPath(db: String): String = {
+ val defaultPath = context.hiveconf.getVar(HiveConf.ConfVars.METASTOREWAREHOUSE)
+ new Path(new Path(defaultPath), db + ".db").toString
+ }
+
// Catalog for handling data source tables. TODO: This really doesn't belong here since it is
// essentially a cache for metastore tables. However, it relies on a lot of session-specific
// things so it would be a lot of work to split its functionality between HiveSessionCatalog
@@ -64,6 +86,7 @@ class HiveSessionCatalog(
private val metastoreCatalog = new HiveMetastoreCatalog(client, context)
val ParquetConversions: Rule[LogicalPlan] = metastoreCatalog.ParquetConversions
+ val OrcConversions: Rule[LogicalPlan] = metastoreCatalog.OrcConversions
val CreateTables: Rule[LogicalPlan] = metastoreCatalog.CreateTables
val PreInsertionCasts: Rule[LogicalPlan] = metastoreCatalog.PreInsertionCasts
@@ -71,7 +94,7 @@ class HiveSessionCatalog(
metastoreCatalog.refreshTable(name)
}
- def invalidateTable(name: TableIdentifier): Unit = {
+ override def invalidateTable(name: TableIdentifier): Unit = {
metastoreCatalog.invalidateTable(name)
}
@@ -101,4 +124,129 @@ class HiveSessionCatalog(
metastoreCatalog.cachedDataSourceTables.getIfPresent(key)
}
+ override def makeFunctionBuilder(funcName: String, className: String): FunctionBuilder = {
+ makeFunctionBuilder(funcName, Utils.classForName(className))
+ }
+
+ /**
+ * Construct a [[FunctionBuilder]] based on the provided class that represents a function.
+ */
+ private def makeFunctionBuilder(name: String, clazz: Class[_]): FunctionBuilder = {
+    // Instantiating the Hive UDF wrapper class may throw an exception if the input
+    // expressions do not satisfy the Hive UDF, e.g. a type mismatch or the wrong
+    // number of inputs. Here we catch the exception and throw an AnalysisException instead.
+ (children: Seq[Expression]) => {
+ try {
+ if (classOf[UDF].isAssignableFrom(clazz)) {
+ val udf = HiveSimpleUDF(name, new HiveFunctionWrapper(clazz.getName), children)
+ udf.dataType // Force it to check input data types.
+ udf
+ } else if (classOf[GenericUDF].isAssignableFrom(clazz)) {
+ val udf = HiveGenericUDF(name, new HiveFunctionWrapper(clazz.getName), children)
+ udf.dataType // Force it to check input data types.
+ udf
+ } else if (classOf[AbstractGenericUDAFResolver].isAssignableFrom(clazz)) {
+ val udaf = HiveUDAFFunction(name, new HiveFunctionWrapper(clazz.getName), children)
+ udaf.dataType // Force it to check input data types.
+ udaf
+ } else if (classOf[UDAF].isAssignableFrom(clazz)) {
+ val udaf = HiveUDAFFunction(
+ name,
+ new HiveFunctionWrapper(clazz.getName),
+ children,
+ isUDAFBridgeRequired = true)
+ udaf.dataType // Force it to check input data types.
+ udaf
+ } else if (classOf[GenericUDTF].isAssignableFrom(clazz)) {
+ val udtf = HiveGenericUDTF(name, new HiveFunctionWrapper(clazz.getName), children)
+ udtf.elementTypes // Force it to check input data types.
+ udtf
+ } else {
+ throw new AnalysisException(s"No handler for Hive UDF '${clazz.getCanonicalName}'")
+ }
+ } catch {
+ case ae: AnalysisException =>
+ throw ae
+ case NonFatal(e) =>
+ val analysisException =
+ new AnalysisException(s"No handler for Hive UDF '${clazz.getCanonicalName}': $e")
+ analysisException.setStackTrace(e.getStackTrace)
+ throw analysisException
+ }
+ }
+ }
+
+  // There is a list of Hive built-in functions that we do not natively implement. For those,
+  // we check Hive's function registry and lazily load the needed functions into our own
+  // function registry. Those Hive built-in functions are
+ // assert_true, collect_list, collect_set, compute_stats, context_ngrams, create_union,
+  // current_user, elt, ewah_bitmap, ewah_bitmap_and, ewah_bitmap_empty, ewah_bitmap_or, field,
+ // histogram_numeric, in_file, index, inline, java_method, map_keys, map_values,
+ // matchpath, ngrams, noop, noopstreaming, noopwithmap, noopwithmapstreaming,
+ // parse_url, parse_url_tuple, percentile, percentile_approx, posexplode, reflect, reflect2,
+ // regexp, sentences, stack, std, str_to_map, windowingtablefunction, xpath, xpath_boolean,
+ // xpath_double, xpath_float, xpath_int, xpath_long, xpath_number,
+ // xpath_short, and xpath_string.
+ override def lookupFunction(name: String, children: Seq[Expression]): Expression = {
+ // TODO: Once lookupFunction accepts a FunctionIdentifier, we should refactor this method to
+ // if (super.functionExists(name)) {
+ // super.lookupFunction(name, children)
+ // } else {
+ // // This function is a Hive builtin function.
+ // ...
+ // }
+ Try(super.lookupFunction(name, children)) match {
+ case Success(expr) => expr
+ case Failure(error) =>
+ if (functionRegistry.functionExists(name)) {
+ // If the function actually exists in functionRegistry, it means that there is an
+ // error when we create the Expression using the given children.
+ // We need to throw the original exception.
+ throw error
+ } else {
+ // This function is not in functionRegistry, let's try to load it as a Hive's
+ // built-in function.
+ // Hive is case insensitive.
+ val functionName = name.toLowerCase
+ // TODO: This may not really work for current_user because current_user is not evaluated
+ // with session info.
+          // We do not need to use executionHive here because we only load
+          // Hive's built-in functions, which do not need the current database.
+ val functionInfo = {
+ try {
+ Option(HiveFunctionRegistry.getFunctionInfo(functionName)).getOrElse(
+ failFunctionLookup(name))
+ } catch {
+ // If HiveFunctionRegistry.getFunctionInfo throws an exception,
+ // we are failing to load a Hive builtin function, which means that
+ // the given function is not a Hive builtin function.
+ case NonFatal(e) => failFunctionLookup(name)
+ }
+ }
+ val className = functionInfo.getFunctionClass.getName
+ val builder = makeFunctionBuilder(functionName, className)
+ // Put this Hive built-in function to our function registry.
+ val info = new ExpressionInfo(className, functionName)
+ createTempFunction(functionName, info, builder, ignoreIfExists = false)
+ // Now, we need to create the Expression.
+ functionRegistry.lookupFunction(functionName, children)
+ }
+ }
+ }
+
+ // Pre-load a few commonly used Hive built-in functions.
+ HiveSessionCatalog.preloadedHiveBuiltinFunctions.foreach {
+ case (functionName, clazz) =>
+ val builder = makeFunctionBuilder(functionName, clazz)
+ val info = new ExpressionInfo(clazz.getCanonicalName, functionName)
+ createTempFunction(functionName, info, builder, ignoreIfExists = false)
+ }
+}
+
+private[sql] object HiveSessionCatalog {
+ // This is the list of Hive's built-in functions that are commonly used and we want to
+ // pre-load when we create the FunctionRegistry.
+ val preloadedHiveBuiltinFunctions =
+ ("collect_set", classOf[org.apache.hadoop.hive.ql.udf.generic.GenericUDAFCollectSet]) ::
+ ("collect_list", classOf[org.apache.hadoop.hive.ql.udf.generic.GenericUDAFCollectList]) :: Nil
}
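
The lookupFunction override above tries the session's own registry first and, only when the name is genuinely unregistered, falls back to Hive's registry of built-ins and caches the result as a temporary function; an existing name that fails with the given children rethrows the original error. Below is a standalone sketch of that try-then-fall-back-and-cache flow, with plain maps standing in for the real function registries.

import scala.collection.mutable
import scala.util.{Failure, Success, Try}

// Plain maps stand in for the session FunctionRegistry and Hive's registry.
object LazyBuiltinLookupSketch {
  type Builder = Seq[Int] => Int

  val localRegistry: mutable.Map[String, Builder] =
    mutable.Map("abs" -> ((args: Seq[Int]) => math.abs(args.head)))
  val hiveBuiltins: Map[String, Builder] =
    Map("elt" -> ((args: Seq[Int]) => args(args.head)))

  def lookupFunction(name: String, args: Seq[Int]): Int = {
    val key = name.toLowerCase
    Try(localRegistry(key)(args)) match {
      case Success(result) => result
      case Failure(error) =>
        if (localRegistry.contains(key)) {
          // The function exists locally, so the failure came from the arguments:
          // rethrow the original error instead of masking it.
          throw error
        } else {
          // Not registered yet: fall back to the (simulated) Hive registry and cache it.
          val builder = hiveBuiltins.getOrElse(key,
            throw new NoSuchElementException(s"Undefined function: $name"))
          localRegistry(key) = builder
          builder(args)
        }
    }
  }

  def main(args: Array[String]): Unit = {
    println(lookupFunction("elt", Seq(1, 42, 7))) // 42, loaded lazily from hiveBuiltins
    println(localRegistry.contains("elt"))        // true: cached after first use
  }
}
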
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala
index caa7f296ed..b992fda18c 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala
@@ -22,6 +22,7 @@ import org.apache.spark.sql.catalyst.analysis.{Analyzer, FunctionRegistry}
import org.apache.spark.sql.catalyst.parser.ParserInterface
import org.apache.spark.sql.execution.{python, SparkPlanner}
import org.apache.spark.sql.execution.datasources._
+import org.apache.spark.sql.hive.execution.HiveSqlParser
import org.apache.spark.sql.internal.{SessionState, SQLConf}
@@ -38,27 +39,25 @@ private[hive] class HiveSessionState(ctx: HiveContext) extends SessionState(ctx)
* Internal catalog for managing table and database states.
*/
override lazy val catalog = {
- new HiveSessionCatalog(ctx.hiveCatalog, ctx.metadataHive, ctx, conf)
- }
-
- /**
- * Internal catalog for managing functions registered by the user.
- * Note that HiveUDFs will be overridden by functions registered in this context.
- */
- override lazy val functionRegistry: FunctionRegistry = {
- new HiveFunctionRegistry(FunctionRegistry.builtin.copy(), ctx.executionHive)
+ new HiveSessionCatalog(
+ ctx.hiveCatalog,
+ ctx.metadataHive,
+ ctx,
+ ctx.functionResourceLoader,
+ functionRegistry,
+ conf)
}
/**
* An analyzer that uses the Hive metastore.
*/
override lazy val analyzer: Analyzer = {
- new Analyzer(catalog, functionRegistry, conf) {
+ new Analyzer(catalog, conf) {
override val extendedResolutionRules =
catalog.ParquetConversions ::
+ catalog.OrcConversions ::
catalog.CreateTables ::
catalog.PreInsertionCasts ::
- python.ExtractPythonUDFs ::
PreInsertCastAndRename ::
DataSourceAnalysis ::
(if (conf.runSQLOnFile) new ResolveDataSource(ctx) :: Nil else Nil)
@@ -70,13 +69,14 @@ private[hive] class HiveSessionState(ctx: HiveContext) extends SessionState(ctx)
/**
* Parser for HiveQl query texts.
*/
- override lazy val sqlParser: ParserInterface = new HiveQl(conf)
+ override lazy val sqlParser: ParserInterface = HiveSqlParser
/**
* Planner that takes into account Hive-specific strategies.
*/
- override lazy val planner: SparkPlanner = {
- new SparkPlanner(ctx.sparkContext, conf, experimentalMethods) with HiveStrategies {
+ override def planner: SparkPlanner = {
+ new SparkPlanner(ctx.sparkContext, conf, experimentalMethods.extraStrategies)
+ with HiveStrategies {
override val hiveContext = ctx
override def strategies: Seq[Strategy] = {
@@ -92,7 +92,7 @@ private[hive] class HiveSessionState(ctx: HiveContext) extends SessionState(ctx)
DataSinks,
Scripts,
Aggregation,
- LeftSemiJoin,
+ ExistenceJoin,
EquiJoinSelection,
BasicOperators,
BroadcastNestedLoop,
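
The analyzer above assembles its extendedResolutionRules by consing fixed rules onto a tail that is only non-empty when conf.runSQLOnFile is set. A small standalone sketch of that conditional-list idiom, with plain functions standing in for catalyst rules:

object RuleChainSketch {
  type Plan = String
  type Rule = Plan => Plan

  val parquetConversions: Rule = p => p + " +ParquetConversions"
  val orcConversions: Rule = p => p + " +OrcConversions"
  val resolveDataSource: Rule = p => p + " +ResolveDataSource"

  // Mirrors the cons-based construction above: the optional rule contributes
  // an empty tail when the flag is off.
  def extendedResolutionRules(runSQLOnFile: Boolean): Seq[Rule] =
    parquetConversions ::
      orcConversions ::
      (if (runSQLOnFile) resolveDataSource :: Nil else Nil)

  def main(args: Array[String]): Unit = {
    val applied = extendedResolutionRules(runSQLOnFile = true)
      .foldLeft("plan")((p, rule) => rule(p))
    println(applied) // plan +ParquetConversions +OrcConversions +ResolveDataSource
  }
}
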
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala
index da910533d0..0d2a765a38 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala
@@ -24,8 +24,6 @@ import scala.collection.JavaConverters._
import scala.language.implicitConversions
import scala.reflect.ClassTag
-import com.esotericsoftware.kryo.Kryo
-import com.esotericsoftware.kryo.io.{Input, Output}
import com.google.common.base.Objects
import org.apache.avro.Schema
import org.apache.hadoop.conf.Configuration
@@ -37,6 +35,8 @@ import org.apache.hadoop.hive.serde2.ColumnProjectionUtils
import org.apache.hadoop.hive.serde2.avro.{AvroGenericRecordWritable, AvroSerdeUtils}
import org.apache.hadoop.hive.serde2.objectinspector.primitive.HiveDecimalObjectInspector
import org.apache.hadoop.io.Writable
+import org.apache.hive.com.esotericsoftware.kryo.Kryo
+import org.apache.hive.com.esotericsoftware.kryo.io.{Input, Output}
import org.apache.spark.internal.Logging
import org.apache.spark.sql.types.Decimal
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala
index f44937ec6f..010361a32e 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala
@@ -23,9 +23,8 @@ import org.apache.spark.sql.catalyst.planning._
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution._
-import org.apache.spark.sql.execution.command.{DescribeCommand => RunnableDescribeCommand, _}
-import org.apache.spark.sql.execution.datasources.{CreateTableUsing, CreateTableUsingAsSelect,
- DescribeCommand}
+import org.apache.spark.sql.execution.command.{DescribeCommand => _, _}
+import org.apache.spark.sql.execution.datasources.{CreateTableUsing, CreateTableUsingAsSelect, CreateTempTableUsingAsSelect, DescribeCommand}
import org.apache.spark.sql.hive.execution._
private[hive] trait HiveStrategies {
@@ -90,6 +89,11 @@ private[hive] trait HiveStrategies {
tableIdent, userSpecifiedSchema, provider, opts, allowExisting, managedIfNoPath)
ExecutedCommand(cmd) :: Nil
+ case c: CreateTableUsingAsSelect if c.temporary =>
+ val cmd = CreateTempTableUsingAsSelect(
+ c.tableIdent, c.provider, c.partitionColumns, c.mode, c.options, c.child)
+ ExecutedCommand(cmd) :: Nil
+
case c: CreateTableUsingAsSelect =>
val cmd = CreateMetastoreDataSourceAsSelect(c.tableIdent, c.provider, c.partitionColumns,
c.bucketSpec, c.mode, c.options, c.child)
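
The strategy change above plans a temporary CREATE TABLE ... AS SELECT with CreateTempTableUsingAsSelect and everything else with the metastore-backed command; the dispatch is a guarded pattern match on the same plan type. A simplified standalone sketch of that dispatch, using stand-in case classes rather than the real plans and commands:

object StrategySketch {
  // Simplified stand-ins for the real logical plan and runnable commands.
  case class CreateTableUsingAsSelect(table: String, temporary: Boolean)

  sealed trait Command
  case class CreateTempTableUsingAsSelect(table: String) extends Command
  case class CreateMetastoreDataSourceAsSelect(table: String) extends Command

  // Guarded dispatch: a temporary CTAS never touches the metastore.
  def plan(c: CreateTableUsingAsSelect): Command = c match {
    case t if t.temporary => CreateTempTableUsingAsSelect(t.table)
    case t                => CreateMetastoreDataSourceAsSelect(t.table)
  }

  def main(args: Array[String]): Unit = {
    println(plan(CreateTableUsingAsSelect("t_tmp", temporary = true)))
    println(plan(CreateTableUsingAsSelect("t_persistent", temporary = false)))
  }
}
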
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala
index 80b24dc989..54afe9c2a3 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala
@@ -34,6 +34,7 @@ import org.apache.hadoop.io.Writable
import org.apache.hadoop.mapred.{FileInputFormat, InputFormat, JobConf}
import org.apache.spark.broadcast.Broadcast
+import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.internal.Logging
import org.apache.spark.rdd.{EmptyRDD, HadoopRDD, RDD, UnionRDD}
import org.apache.spark.sql.catalyst.InternalRow
@@ -74,8 +75,7 @@ class HadoopTableReader(
math.max(sc.hiveconf.getInt("mapred.map.tasks", 1), sc.sparkContext.defaultMinPartitions)
}
- // TODO: set aws s3 credentials.
-
+ SparkHadoopUtil.get.appendS3AndSparkHadoopConfigurations(sc.sparkContext.conf, hiveExtraConf)
private val _broadcastedHiveConf =
sc.sparkContext.broadcast(new SerializableConfiguration(hiveExtraConf))
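
The TableReader change above replaces the old "set aws s3 credentials" TODO with a call that folds Spark-side Hadoop settings into the configuration that is then broadcast to the readers. The sketch below illustrates only the spark.hadoop.* propagation part of that idea, with plain maps standing in for SparkConf and Hadoop's Configuration; the environment-based S3 credential handling is omitted.

object HadoopConfMergeSketch {
  // Copies keys prefixed with "spark.hadoop." (prefix stripped) into the Hadoop config map.
  def appendSparkHadoopConfigs(sparkConf: Map[String, String],
                               hadoopConf: Map[String, String]): Map[String, String] = {
    val prefix = "spark.hadoop."
    val extracted = sparkConf.collect {
      case (k, v) if k.startsWith(prefix) => k.stripPrefix(prefix) -> v
    }
    hadoopConf ++ extracted
  }

  def main(args: Array[String]): Unit = {
    val spark = Map("spark.hadoop.fs.s3a.endpoint" -> "s3.example.com", "spark.app.name" -> "x")
    println(appendSparkHadoopConfigs(spark, Map.empty))
    // Map(fs.s3a.endpoint -> s3.example.com)
  }
}
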
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClient.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClient.scala
index f4d30358ca..6f7e7bf451 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClient.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClient.scala
@@ -88,7 +88,7 @@ private[hive] trait HiveClient {
def dropTable(dbName: String, tableName: String, ignoreIfNotExists: Boolean): Unit
/** Alter a table whose name matches the one specified in `table`, assuming it exists. */
- final def alterTable(table: CatalogTable): Unit = alterTable(table.name.table, table)
+ final def alterTable(table: CatalogTable): Unit = alterTable(table.identifier.table, table)
/** Updates the given table with new metadata, optionally renaming the table. */
def alterTable(tableName: String, table: CatalogTable): Unit
@@ -120,16 +120,13 @@ private[hive] trait HiveClient {
ignoreIfExists: Boolean): Unit
/**
- * Drop one or many partitions in the given table.
- *
- * Note: Unfortunately, Hive does not currently provide a way to ignore this call if the
- * partitions do not already exist. The seemingly relevant flag `ifExists` in
- * [[org.apache.hadoop.hive.metastore.PartitionDropOptions]] is not read anywhere.
+ * Drop one or many partitions in the given table, assuming they exist.
*/
def dropPartitions(
db: String,
table: String,
- specs: Seq[ExternalCatalog.TablePartitionSpec]): Unit
+ specs: Seq[ExternalCatalog.TablePartitionSpec],
+ ignoreIfNotExists: Boolean): Unit
/**
* Rename one or many existing table partitions, assuming they exist.
@@ -232,6 +229,11 @@ private[hive] trait HiveClient {
/** Return an existing function in the database, or None if it doesn't exist. */
def getFunctionOption(db: String, name: String): Option[CatalogFunction]
+ /** Return whether a function exists in the specified database. */
+ final def functionExists(db: String, name: String): Boolean = {
+ getFunctionOption(db, name).isDefined
+ }
+
/** Return the names of all functions that match the given pattern in the database. */
def listFunctions(db: String, pattern: String): Seq[String]
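
The functionExists helper added to the trait above is a final method defined purely in terms of the abstract getFunctionOption, so every client implementation gets it for free. A minimal standalone sketch of the same trait pattern, with an in-memory stand-in client:

object ClientTraitSketch {
  trait Client {
    def getFunctionOption(db: String, name: String): Option[String]
    // Final helper expressed purely in terms of the abstract lookup, as in the trait above.
    final def functionExists(db: String, name: String): Boolean =
      getFunctionOption(db, name).isDefined
  }

  object InMemoryClient extends Client {
    // Hypothetical class name, for illustration only.
    private val functions = Map(("default", "upper") -> "org.example.UpperUDF")
    override def getFunctionOption(db: String, name: String): Option[String] =
      functions.get((db, name))
  }

  def main(args: Array[String]): Unit = {
    println(InMemoryClient.functionExists("default", "upper"))   // true
    println(InMemoryClient.functionExists("default", "missing")) // false
  }
}
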
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala
index e4e15d13df..2a1fff92b5 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala
@@ -26,10 +26,11 @@ import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path
import org.apache.hadoop.hive.cli.CliSessionState
import org.apache.hadoop.hive.conf.HiveConf
-import org.apache.hadoop.hive.metastore.{TableType => HiveTableType}
-import org.apache.hadoop.hive.metastore.api.{Database => HiveDatabase, FieldSchema, Function => HiveFunction, FunctionType, PrincipalType, ResourceUri}
+import org.apache.hadoop.hive.metastore.{PartitionDropOptions, TableType => HiveTableType}
+import org.apache.hadoop.hive.metastore.api.{Database => HiveDatabase, FieldSchema, Function => HiveFunction, FunctionType, PrincipalType, ResourceType, ResourceUri}
import org.apache.hadoop.hive.ql.Driver
-import org.apache.hadoop.hive.ql.metadata.{Hive, Partition => HivePartition, Table => HiveTable}
+import org.apache.hadoop.hive.ql.metadata.{Partition => HivePartition, Table => HiveTable}
+import org.apache.hadoop.hive.ql.metadata.{Hive, HiveException}
import org.apache.hadoop.hive.ql.plan.AddPartitionDesc
import org.apache.hadoop.hive.ql.processors._
import org.apache.hadoop.hive.ql.session.SessionState
@@ -37,6 +38,7 @@ import org.apache.hadoop.security.UserGroupInformation
import org.apache.spark.{SparkConf, SparkException}
import org.apache.spark.internal.Logging
+import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier}
import org.apache.spark.sql.catalyst.analysis.{NoSuchDatabaseException, NoSuchPartitionException}
import org.apache.spark.sql.catalyst.catalog._
@@ -297,17 +299,22 @@ private[hive] class HiveClientImpl(
tableName: String): Option[CatalogTable] = withHiveState {
logDebug(s"Looking up $dbName.$tableName")
Option(client.getTable(dbName, tableName, false)).map { h =>
+ // Note: Hive separates partition columns and the schema, but for us the
+ // partition columns are part of the schema
+ val partCols = h.getPartCols.asScala.map(fromHiveColumn)
+ val schema = h.getCols.asScala.map(fromHiveColumn) ++ partCols
CatalogTable(
- name = TableIdentifier(h.getTableName, Option(h.getDbName)),
+ identifier = TableIdentifier(h.getTableName, Option(h.getDbName)),
tableType = h.getTableType match {
case HiveTableType.EXTERNAL_TABLE => CatalogTableType.EXTERNAL_TABLE
case HiveTableType.MANAGED_TABLE => CatalogTableType.MANAGED_TABLE
case HiveTableType.INDEX_TABLE => CatalogTableType.INDEX_TABLE
case HiveTableType.VIRTUAL_VIEW => CatalogTableType.VIRTUAL_VIEW
},
- schema = h.getCols.asScala.map(fromHiveColumn),
- partitionColumns = h.getPartCols.asScala.map(fromHiveColumn),
- sortColumns = Seq(),
+ schema = schema,
+ partitionColumnNames = partCols.map(_.name),
+ sortColumnNames = Seq(), // TODO: populate this
+ bucketColumnNames = h.getBucketCols.asScala,
numBuckets = h.getNumBuckets,
createTime = h.getTTable.getCreateTime.toLong * 1000,
lastAccessTime = h.getLastAccessTime.toLong * 1000,
@@ -365,9 +372,25 @@ private[hive] class HiveClientImpl(
override def dropPartitions(
db: String,
table: String,
- specs: Seq[ExternalCatalog.TablePartitionSpec]): Unit = withHiveState {
+ specs: Seq[ExternalCatalog.TablePartitionSpec],
+ ignoreIfNotExists: Boolean): Unit = withHiveState {
// TODO: figure out how to drop multiple partitions in one call
- specs.foreach { s => client.dropPartition(db, table, s.values.toList.asJava, true) }
+ val hiveTable = client.getTable(db, table, true /* throw exception */)
+ specs.foreach { s =>
+ // The provided spec here can be a partial spec, i.e. it will match all partitions
+ // whose specs are supersets of this partial spec. E.g. If a table has partitions
+ // (b='1', c='1') and (b='1', c='2'), a partial spec of (b='1') will match both.
+ val matchingParts = client.getPartitions(hiveTable, s.asJava).asScala
+ if (matchingParts.isEmpty && !ignoreIfNotExists) {
+ throw new AnalysisException(
+ s"partition to drop '$s' does not exist in table '$table' database '$db'")
+ }
+ matchingParts.foreach { hivePartition =>
+ val dropOptions = new PartitionDropOptions
+ dropOptions.ifExists = ignoreIfNotExists
+ client.dropPartition(db, table, hivePartition.getValues, dropOptions)
+ }
+ }
}
override def renamePartitions(
@@ -544,19 +567,24 @@ private[hive] class HiveClientImpl(
}
override def renameFunction(db: String, oldName: String, newName: String): Unit = withHiveState {
- val catalogFunc = getFunction(db, oldName).copy(name = FunctionIdentifier(newName, Some(db)))
+ val catalogFunc = getFunction(db, oldName)
+ .copy(identifier = FunctionIdentifier(newName, Some(db)))
val hiveFunc = toHiveFunction(catalogFunc, db)
client.alterFunction(db, oldName, hiveFunc)
}
override def alterFunction(db: String, func: CatalogFunction): Unit = withHiveState {
- client.alterFunction(db, func.name.funcName, toHiveFunction(func, db))
+ client.alterFunction(db, func.identifier.funcName, toHiveFunction(func, db))
}
override def getFunctionOption(
db: String,
name: String): Option[CatalogFunction] = withHiveState {
- Option(client.getFunction(db, name)).map(fromHiveFunction)
+ try {
+ Option(client.getFunction(db, name)).map(fromHiveFunction)
+ } catch {
+ case he: HiveException => None
+ }
}
override def listFunctions(db: String, pattern: String): Seq[String] = withHiveState {
@@ -610,20 +638,32 @@ private[hive] class HiveClientImpl(
.asInstanceOf[Class[_ <: org.apache.hadoop.hive.ql.io.HiveOutputFormat[_, _]]]
private def toHiveFunction(f: CatalogFunction, db: String): HiveFunction = {
+ val resourceUris = f.resources.map { case (resourceType, resourcePath) =>
+ new ResourceUri(ResourceType.valueOf(resourceType.toUpperCase), resourcePath)
+ }
new HiveFunction(
- f.name.funcName,
+ f.identifier.funcName,
db,
f.className,
null,
PrincipalType.USER,
(System.currentTimeMillis / 1000).toInt,
FunctionType.JAVA,
- List.empty[ResourceUri].asJava)
+ resourceUris.asJava)
}
private def fromHiveFunction(hf: HiveFunction): CatalogFunction = {
val name = FunctionIdentifier(hf.getFunctionName, Option(hf.getDbName))
- new CatalogFunction(name, hf.getClassName)
+ val resources = hf.getResourceUris.asScala.map { uri =>
+ val resourceType = uri.getResourceType() match {
+ case ResourceType.ARCHIVE => "archive"
+ case ResourceType.FILE => "file"
+ case ResourceType.JAR => "jar"
+ case r => throw new AnalysisException(s"Unknown resource type: $r")
+ }
+ (resourceType, uri.getUri())
+ }
+ new CatalogFunction(name, hf.getClassName, resources)
}
private def toHiveColumn(c: CatalogColumn): FieldSchema = {
@@ -639,16 +679,38 @@ private[hive] class HiveClientImpl(
}
private def toHiveTable(table: CatalogTable): HiveTable = {
- val hiveTable = new HiveTable(table.database, table.name.table)
+ val hiveTable = new HiveTable(table.database, table.identifier.table)
+ // For EXTERNAL_TABLE, we also need to set EXTERNAL field in the table properties.
+ // Otherwise, Hive metastore will change the table to a MANAGED_TABLE.
+ // (metastore/src/java/org/apache/hadoop/hive/metastore/ObjectStore.java#L1095-L1105)
hiveTable.setTableType(table.tableType match {
- case CatalogTableType.EXTERNAL_TABLE => HiveTableType.EXTERNAL_TABLE
- case CatalogTableType.MANAGED_TABLE => HiveTableType.MANAGED_TABLE
+ case CatalogTableType.EXTERNAL_TABLE =>
+ hiveTable.setProperty("EXTERNAL", "TRUE")
+ HiveTableType.EXTERNAL_TABLE
+ case CatalogTableType.MANAGED_TABLE =>
+ HiveTableType.MANAGED_TABLE
case CatalogTableType.INDEX_TABLE => HiveTableType.INDEX_TABLE
case CatalogTableType.VIRTUAL_VIEW => HiveTableType.VIRTUAL_VIEW
})
- hiveTable.setFields(table.schema.map(toHiveColumn).asJava)
- hiveTable.setPartCols(table.partitionColumns.map(toHiveColumn).asJava)
+ // Note: In Hive the schema and partition columns must be disjoint sets
+ val (partCols, schema) = table.schema.map(toHiveColumn).partition { c =>
+ table.partitionColumnNames.contains(c.getName)
+ }
+ if (table.schema.isEmpty) {
+      // This is a hack to preserve existing behavior. Before Spark 2.0, we did not
+      // set a default serde here (that was done in Hive), so if the user provided
+      // an empty schema, Hive would automatically populate it with a single
+      // field "col". However, after SPARK-14388 we set the default serde to
+      // LazySimpleSerDe, so this implicit behavior no longer happens. Therefore,
+      // we need to do it in Spark ourselves.
+ hiveTable.setFields(
+ Seq(new FieldSchema("col", "array<string>", "from deserializer")).asJava)
+ } else {
+ hiveTable.setFields(schema.asJava)
+ }
+ hiveTable.setPartCols(partCols.asJava)
// TODO: set sort columns here too
+ hiveTable.setBucketCols(table.bucketColumnNames.asJava)
hiveTable.setOwner(conf.getUser)
hiveTable.setNumBuckets(table.numBuckets)
hiveTable.setCreateTime((table.createTime / 1000).toInt)
@@ -656,9 +718,11 @@ private[hive] class HiveClientImpl(
table.storage.locationUri.foreach { loc => shim.setDataLocation(hiveTable, loc) }
table.storage.inputFormat.map(toInputFormat).foreach(hiveTable.setInputFormatClass)
table.storage.outputFormat.map(toOutputFormat).foreach(hiveTable.setOutputFormatClass)
- table.storage.serde.foreach(hiveTable.setSerializationLib)
+ hiveTable.setSerializationLib(
+ table.storage.serde.getOrElse("org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe"))
table.storage.serdeProperties.foreach { case (k, v) => hiveTable.setSerdeParam(k, v) }
table.properties.foreach { case (k, v) => hiveTable.setProperty(k, v) }
+ table.comment.foreach { c => hiveTable.setProperty("comment", c) }
table.viewOriginalText.foreach { t => hiveTable.setViewOriginalText(t) }
table.viewText.foreach { t => hiveTable.setViewExpandedText(t) }
hiveTable
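
The new dropPartitions implementation above treats each provided spec as possibly partial: it matches every partition whose full spec is a superset of it, and only fails when nothing matches and ignoreIfNotExists is false. A standalone sketch of that superset matching, with plain maps standing in for Hive partition objects:

object PartialSpecSketch {
  type Spec = Map[String, String]

  // A partial spec matches a partition when every (key, value) it mentions
  // also appears in the partition's full spec.
  def matches(partial: Spec, full: Spec): Boolean =
    partial.forall { case (k, v) => full.get(k).contains(v) }

  def matchingPartitions(partitions: Seq[Spec], partial: Spec): Seq[Spec] =
    partitions.filter(matches(partial, _))

  def main(args: Array[String]): Unit = {
    val parts = Seq(
      Map("b" -> "1", "c" -> "1"),
      Map("b" -> "1", "c" -> "2"),
      Map("b" -> "2", "c" -> "1"))
    println(matchingPartitions(parts, Map("b" -> "1")))
    // List(Map(b -> 1, c -> 1), Map(b -> 1, c -> 2))
  }
}
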
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateTableAsSelect.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateTableAsSelect.scala
index 5a61eef0f2..29f7dc2997 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateTableAsSelect.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateTableAsSelect.scala
@@ -38,7 +38,7 @@ case class CreateTableAsSelect(
allowExisting: Boolean)
extends RunnableCommand {
- private val tableIdentifier = tableDesc.name
+ private val tableIdentifier = tableDesc.identifier
override def children: Seq[LogicalPlan] = Seq(query)
@@ -93,6 +93,8 @@ case class CreateTableAsSelect(
}
override def argString: String = {
- s"[Database:${tableDesc.database}}, TableName: ${tableDesc.name.table}, InsertIntoHiveTable]"
+      s"[Database: ${tableDesc.database}, " +
+ s"TableName: ${tableDesc.identifier.table}, " +
+ s"InsertIntoHiveTable]"
}
}
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateViewAsSelect.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateViewAsSelect.scala
index 9ff520da1d..33cd8b4480 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateViewAsSelect.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateViewAsSelect.scala
@@ -44,7 +44,7 @@ private[hive] case class CreateViewAsSelect(
assert(tableDesc.schema == Nil || tableDesc.schema.length == childSchema.length)
assert(tableDesc.viewText.isDefined)
- private val tableIdentifier = tableDesc.name
+ private val tableIdentifier = tableDesc.identifier
override def run(sqlContext: SQLContext): Seq[Row] = {
val hiveContext = sqlContext.asInstanceOf[HiveContext]
@@ -116,7 +116,7 @@ private[hive] case class CreateViewAsSelect(
}
val viewText = tableDesc.viewText.get
- val viewName = quote(tableDesc.name.table)
+ val viewName = quote(tableDesc.identifier.table)
s"SELECT $viewOutput FROM ($viewText) $viewName"
}
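
CreateViewAsSelect above rebuilds the expanded view text by selecting the view's quoted output columns from the original query text wrapped as a subquery. A standalone sketch of that string construction; the backtick quoting helper here is an illustrative assumption, not Hive's own quoting rules:

object ViewExpansionSketch {
  private def quote(name: String): String = s"`$name`"

  def expandedViewText(viewName: String, viewText: String, columns: Seq[String]): String = {
    // Select the view's (quoted) columns from the original query, aliased as the view name.
    val viewOutput = columns.map(quote).mkString(", ")
    s"SELECT $viewOutput FROM ($viewText) ${quote(viewName)}"
  }

  def main(args: Array[String]): Unit = {
    println(expandedViewText("v", "SELECT a, b FROM t", Seq("a", "b")))
    // SELECT `a`, `b` FROM (SELECT a, b FROM t) `v`
  }
}
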
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveSqlParser.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveSqlParser.scala
new file mode 100644
index 0000000000..a97b65e27b
--- /dev/null
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveSqlParser.scala
@@ -0,0 +1,503 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.hive.execution
+
+import scala.collection.JavaConverters._
+
+import org.antlr.v4.runtime.{ParserRuleContext, Token}
+import org.apache.hadoop.hive.conf.HiveConf
+import org.apache.hadoop.hive.conf.HiveConf.ConfVars
+import org.apache.hadoop.hive.ql.parse.EximUtil
+import org.apache.hadoop.hive.ql.session.SessionState
+import org.apache.hadoop.hive.serde.serdeConstants
+import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe
+
+import org.apache.spark.sql.catalyst.catalog._
+import org.apache.spark.sql.catalyst.parser._
+import org.apache.spark.sql.catalyst.parser.SqlBaseParser._
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.execution.SparkSqlAstBuilder
+import org.apache.spark.sql.execution.command.{CreateTable, CreateTableLike}
+import org.apache.spark.sql.hive.{CreateTableAsSelect => CTAS, CreateViewAsSelect => CreateView, HiveSerDe}
+import org.apache.spark.sql.hive.{HiveGenericUDTF, HiveMetastoreTypes, HiveSerDe}
+import org.apache.spark.sql.hive.HiveShim.HiveFunctionWrapper
+
+/**
+ * Concrete parser for HiveQl statements.
+ */
+object HiveSqlParser extends AbstractSqlParser {
+ val astBuilder = new HiveSqlAstBuilder
+
+ override protected def nativeCommand(sqlText: String): LogicalPlan = {
+ HiveNativeCommand(sqlText)
+ }
+}
+
+/**
+ * Builder that converts an ANTLR ParseTree into a LogicalPlan/Expression/TableIdentifier.
+ */
+class HiveSqlAstBuilder extends SparkSqlAstBuilder {
+ import ParserUtils._
+
+ /**
+ * Get the current Hive Configuration.
+ */
+ private[this] def hiveConf: HiveConf = {
+ var ss = SessionState.get()
+ // SessionState is lazily initialized, so it can be null here
+ if (ss == null) {
+ val original = Thread.currentThread().getContextClassLoader
+ val conf = new HiveConf(classOf[SessionState])
+ conf.setClassLoader(original)
+ ss = new SessionState(conf)
+ SessionState.start(ss)
+ }
+ ss.getConf
+ }
+
+ /**
+ * Pass a command to Hive using a [[HiveNativeCommand]].
+ */
+ override def visitExecuteNativeCommand(
+ ctx: ExecuteNativeCommandContext): LogicalPlan = withOrigin(ctx) {
+ HiveNativeCommand(command(ctx))
+ }
+
+ /**
+ * Fail an unsupported Hive native command.
+ */
+ override def visitFailNativeCommand(
+ ctx: FailNativeCommandContext): LogicalPlan = withOrigin(ctx) {
+ val keywords = if (ctx.kws != null) {
+ Seq(ctx.kws.kw1, ctx.kws.kw2, ctx.kws.kw3).filter(_ != null).map(_.getText).mkString(" ")
+ } else {
+ // SET ROLE is the exception to the rule, because we handle this before other SET commands.
+ "SET ROLE"
+ }
+ throw new ParseException(s"Unsupported operation: $keywords", ctx)
+ }
+
+ /**
+ * Create an [[AddJar]] or [[AddFile]] command depending on the requested resource.
+ */
+ override def visitAddResource(ctx: AddResourceContext): LogicalPlan = withOrigin(ctx) {
+ ctx.identifier.getText.toLowerCase match {
+ case "file" => AddFile(remainder(ctx.identifier).trim)
+ case "jar" => AddJar(remainder(ctx.identifier).trim)
+ case other => throw new ParseException(s"Unsupported resource type '$other'.", ctx)
+ }
+ }
+
+ /**
+ * Create an [[AnalyzeTable]] command. This currently only implements the NOSCAN option (other
+ * options are passed on to Hive) e.g.:
+ * {{{
+ * ANALYZE TABLE table COMPUTE STATISTICS NOSCAN;
+ * }}}
+ */
+ override def visitAnalyze(ctx: AnalyzeContext): LogicalPlan = withOrigin(ctx) {
+ if (ctx.partitionSpec == null &&
+ ctx.identifier != null &&
+ ctx.identifier.getText.toLowerCase == "noscan") {
+ AnalyzeTable(visitTableIdentifier(ctx.tableIdentifier).toString)
+ } else {
+ HiveNativeCommand(command(ctx))
+ }
+ }
+
+ /**
+ * Create a [[CatalogStorageFormat]] for creating tables.
+ */
+ override def visitCreateFileFormat(
+ ctx: CreateFileFormatContext): CatalogStorageFormat = withOrigin(ctx) {
+ (ctx.fileFormat, ctx.storageHandler) match {
+ // Expected format: INPUTFORMAT input_format OUTPUTFORMAT output_format
+ case (c: TableFileFormatContext, null) =>
+ visitTableFileFormat(c)
+ // Expected format: SEQUENCEFILE | TEXTFILE | RCFILE | ORC | PARQUET | AVRO
+ case (c: GenericFileFormatContext, null) =>
+ visitGenericFileFormat(c)
+ case (null, storageHandler) =>
+ throw new ParseException("Operation not allowed: ... STORED BY storage_handler ...", ctx)
+ case _ =>
+ throw new ParseException("expected either STORED AS or STORED BY, not both", ctx)
+ }
+ }
+
+ /**
+ * Create a table, returning either a [[CreateTable]] or a [[CreateTableAsSelect]].
+ *
+ * This is not used to create datasource tables, which is handled through
+ * "CREATE TABLE ... USING ...".
+ *
+ * Note: several features are currently not supported - temporary tables, bucketing,
+ * skewed columns and storage handlers (STORED BY).
+ *
+ * Expected format:
+ * {{{
+ * CREATE [TEMPORARY] [EXTERNAL] TABLE [IF NOT EXISTS] [db_name.]table_name
+ * [(col1 data_type [COMMENT col_comment], ...)]
+ * [COMMENT table_comment]
+ * [PARTITIONED BY (col3 data_type [COMMENT col_comment], ...)]
+ * [CLUSTERED BY (col1, ...) [SORTED BY (col1 [ASC|DESC], ...)] INTO num_buckets BUCKETS]
+ * [SKEWED BY (col1, col2, ...) ON ((col_value, col_value, ...), ...) [STORED AS DIRECTORIES]]
+ * [ROW FORMAT row_format]
+ * [STORED AS file_format | STORED BY storage_handler_class [WITH SERDEPROPERTIES (...)]]
+ * [LOCATION path]
+ * [TBLPROPERTIES (property_name=property_value, ...)]
+ * [AS select_statement];
+ * }}}
+ */
+ override def visitCreateTable(ctx: CreateTableContext): LogicalPlan = withOrigin(ctx) {
+ val (name, temp, ifNotExists, external) = visitCreateTableHeader(ctx.createTableHeader)
+ // TODO: implement temporary tables
+ if (temp) {
+ throw new ParseException(
+ "CREATE TEMPORARY TABLE is not supported yet. " +
+ "Please use registerTempTable as an alternative.", ctx)
+ }
+ if (ctx.skewSpec != null) {
+ throw new ParseException("Operation not allowed: CREATE TABLE ... SKEWED BY ...", ctx)
+ }
+ if (ctx.bucketSpec != null) {
+ throw new ParseException("Operation not allowed: CREATE TABLE ... CLUSTERED BY ...", ctx)
+ }
+ val tableType = if (external) {
+ CatalogTableType.EXTERNAL_TABLE
+ } else {
+ CatalogTableType.MANAGED_TABLE
+ }
+ val comment = Option(ctx.STRING).map(string)
+ val partitionCols = Option(ctx.partitionColumns).toSeq.flatMap(visitCatalogColumns)
+ val cols = Option(ctx.columns).toSeq.flatMap(visitCatalogColumns)
+ val properties = Option(ctx.tablePropertyList).map(visitTablePropertyList).getOrElse(Map.empty)
+ val selectQuery = Option(ctx.query).map(plan)
+
+ // Note: Hive requires partition columns to be distinct from the schema, so we need
+ // to include the partition columns here explicitly
+ val schema = cols ++ partitionCols
+
+ // Storage format
+ val defaultStorage: CatalogStorageFormat = {
+ val defaultStorageType = hiveConf.getVar(HiveConf.ConfVars.HIVEDEFAULTFILEFORMAT)
+ val defaultHiveSerde = HiveSerDe.sourceToSerDe(defaultStorageType, hiveConf)
+ CatalogStorageFormat(
+ locationUri = None,
+ inputFormat = defaultHiveSerde.flatMap(_.inputFormat)
+ .orElse(Some("org.apache.hadoop.mapred.TextInputFormat")),
+ outputFormat = defaultHiveSerde.flatMap(_.outputFormat)
+ .orElse(Some("org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat")),
+ // Note: Keep this unspecified because we use the presence of the serde to decide
+ // whether to convert a table created by CTAS to a datasource table.
+ serde = None,
+ serdeProperties = Map())
+ }
+ val fileStorage = Option(ctx.createFileFormat).map(visitCreateFileFormat)
+ .getOrElse(EmptyStorageFormat)
+ val rowStorage = Option(ctx.rowFormat).map(visitRowFormat).getOrElse(EmptyStorageFormat)
+ val location = Option(ctx.locationSpec).map(visitLocationSpec)
+ val storage = CatalogStorageFormat(
+ locationUri = location,
+ inputFormat = fileStorage.inputFormat.orElse(defaultStorage.inputFormat),
+ outputFormat = fileStorage.outputFormat.orElse(defaultStorage.outputFormat),
+ serde = rowStorage.serde.orElse(fileStorage.serde).orElse(defaultStorage.serde),
+ serdeProperties = rowStorage.serdeProperties ++ fileStorage.serdeProperties)
+
+ // TODO support the sql text - have a proper location for this!
+ val tableDesc = CatalogTable(
+ identifier = name,
+ tableType = tableType,
+ storage = storage,
+ schema = schema,
+ partitionColumnNames = partitionCols.map(_.name),
+ properties = properties,
+ comment = comment)
+
+ selectQuery match {
+ case Some(q) => CTAS(tableDesc, q, ifNotExists)
+ case None => CreateTable(tableDesc, ifNotExists)
+ }
+ }
+
+ /**
+ * Create a [[CreateTableLike]] command.
+ */
+ override def visitCreateTableLike(ctx: CreateTableLikeContext): LogicalPlan = withOrigin(ctx) {
+ val targetTable = visitTableIdentifier(ctx.target)
+ val sourceTable = visitTableIdentifier(ctx.source)
+ CreateTableLike(targetTable, sourceTable, ctx.EXISTS != null)
+ }
+
+ /**
+ * Create or replace a view. This creates a [[CreateViewAsSelect]] command.
+ *
+ * For example:
+ * {{{
+ * CREATE VIEW [IF NOT EXISTS] [db_name.]view_name
+ * [(column_name [COMMENT column_comment], ...) ]
+ * [COMMENT view_comment]
+ * [TBLPROPERTIES (property_name = property_value, ...)]
+ * AS SELECT ...;
+ * }}}
+ */
+ override def visitCreateView(ctx: CreateViewContext): LogicalPlan = withOrigin(ctx) {
+ if (ctx.identifierList != null) {
+ throw new ParseException(s"Operation not allowed: partitioned views", ctx)
+ } else {
+ val identifiers = Option(ctx.identifierCommentList).toSeq.flatMap(_.identifierComment.asScala)
+ val schema = identifiers.map { ic =>
+ CatalogColumn(ic.identifier.getText, null, nullable = true, Option(ic.STRING).map(string))
+ }
+ createView(
+ ctx,
+ ctx.tableIdentifier,
+ comment = Option(ctx.STRING).map(string),
+ schema,
+ ctx.query,
+ Option(ctx.tablePropertyList).map(visitTablePropertyList).getOrElse(Map.empty),
+ ctx.EXISTS != null,
+ ctx.REPLACE != null
+ )
+ }
+ }
+
+ /**
+ * Alter the query of a view. This creates a [[CreateViewAsSelect]] command.
+ */
+ override def visitAlterViewQuery(ctx: AlterViewQueryContext): LogicalPlan = withOrigin(ctx) {
+ createView(
+ ctx,
+ ctx.tableIdentifier,
+ comment = None,
+ Seq.empty,
+ ctx.query,
+ Map.empty,
+ allowExist = false,
+ replace = true)
+ }
+
+ /**
+ * Create a [[CreateViewAsSelect]] command.
+ */
+ private def createView(
+ ctx: ParserRuleContext,
+ name: TableIdentifierContext,
+ comment: Option[String],
+ schema: Seq[CatalogColumn],
+ query: QueryContext,
+ properties: Map[String, String],
+ allowExist: Boolean,
+ replace: Boolean): LogicalPlan = {
+ val sql = Option(source(query))
+ val tableDesc = CatalogTable(
+ identifier = visitTableIdentifier(name),
+ tableType = CatalogTableType.VIRTUAL_VIEW,
+ schema = schema,
+ storage = EmptyStorageFormat,
+ properties = properties,
+ viewOriginalText = sql,
+ viewText = sql,
+ comment = comment)
+ CreateView(tableDesc, plan(query), allowExist, replace, command(ctx))
+ }
+
+ /**
+ * Create a [[HiveScriptIOSchema]].
+ */
+ override protected def withScriptIOSchema(
+ ctx: QuerySpecificationContext,
+ inRowFormat: RowFormatContext,
+ recordWriter: Token,
+ outRowFormat: RowFormatContext,
+ recordReader: Token,
+ schemaLess: Boolean): HiveScriptIOSchema = {
+ if (recordWriter != null || recordReader != null) {
+ throw new ParseException(
+ "Unsupported operation: Used defined record reader/writer classes.", ctx)
+ }
+
+ // Decode an input/output format.
+ type Format = (Seq[(String, String)], Option[String], Seq[(String, String)], Option[String])
+ def format(fmt: RowFormatContext, confVar: ConfVars): Format = fmt match {
+ case c: RowFormatDelimitedContext =>
+ // TODO we should use the visitRowFormatDelimited function here. However HiveScriptIOSchema
+ // expects a seq of pairs in which the old parsers' token names are used as keys.
+ // Transforming the result of visitRowFormatDelimited would be quite a bit messier than
+ // retrieving the key value pairs ourselves.
+ def entry(key: String, value: Token): Seq[(String, String)] = {
+ Option(value).map(t => key -> t.getText).toSeq
+ }
+ val entries = entry("TOK_TABLEROWFORMATFIELD", c.fieldsTerminatedBy) ++
+ entry("TOK_TABLEROWFORMATCOLLITEMS", c.collectionItemsTerminatedBy) ++
+ entry("TOK_TABLEROWFORMATMAPKEYS", c.keysTerminatedBy) ++
+ entry("TOK_TABLEROWFORMATLINES", c.linesSeparatedBy) ++
+ entry("TOK_TABLEROWFORMATNULL", c.nullDefinedAs)
+
+ (entries, None, Seq.empty, None)
+
+ case c: RowFormatSerdeContext =>
+ // Use a serde format.
+ val CatalogStorageFormat(None, None, None, Some(name), props) = visitRowFormatSerde(c)
+
+ // SPARK-10310: Special cases LazySimpleSerDe
+ val recordHandler = if (name == classOf[LazySimpleSerDe].getCanonicalName) {
+ Option(hiveConf.getVar(confVar))
+ } else {
+ None
+ }
+ (Seq.empty, Option(name), props.toSeq, recordHandler)
+
+ case null =>
+ // Use default (serde) format.
+ val name = hiveConf.getVar(ConfVars.HIVESCRIPTSERDE)
+ val props = Seq(serdeConstants.FIELD_DELIM -> "\t")
+ val recordHandler = Option(hiveConf.getVar(confVar))
+ (Nil, Option(name), props, recordHandler)
+ }
+
+ val (inFormat, inSerdeClass, inSerdeProps, reader) =
+ format(inRowFormat, ConfVars.HIVESCRIPTRECORDREADER)
+
+ val (outFormat, outSerdeClass, outSerdeProps, writer) =
+ format(outRowFormat, ConfVars.HIVESCRIPTRECORDWRITER)
+
+ HiveScriptIOSchema(
+ inFormat, outFormat,
+ inSerdeClass, outSerdeClass,
+ inSerdeProps, outSerdeProps,
+ reader, writer,
+ schemaLess)
+ }
+
+ /**
+ * Create location string.
+ */
+ override def visitLocationSpec(ctx: LocationSpecContext): String = {
+ EximUtil.relativeToAbsolutePath(hiveConf, super.visitLocationSpec(ctx))
+ }
+
+ /** Empty storage format for default values and copies. */
+ private val EmptyStorageFormat = CatalogStorageFormat(None, None, None, None, Map.empty)
+
+ /**
+ * Create a [[CatalogStorageFormat]].
+ */
+ override def visitTableFileFormat(
+ ctx: TableFileFormatContext): CatalogStorageFormat = withOrigin(ctx) {
+ EmptyStorageFormat.copy(
+ inputFormat = Option(string(ctx.inFmt)),
+ outputFormat = Option(string(ctx.outFmt)),
+ serde = Option(ctx.serdeCls).map(string)
+ )
+ }
+
+ /**
+ * Resolve a [[HiveSerDe]] based on the name given and return it as a [[CatalogStorageFormat]].
+ */
+ override def visitGenericFileFormat(
+ ctx: GenericFileFormatContext): CatalogStorageFormat = withOrigin(ctx) {
+ val source = ctx.identifier.getText
+ HiveSerDe.sourceToSerDe(source, hiveConf) match {
+ case Some(s) =>
+ EmptyStorageFormat.copy(
+ inputFormat = s.inputFormat,
+ outputFormat = s.outputFormat,
+ serde = s.serde)
+ case None =>
+ throw new ParseException(s"Unrecognized file format in STORED AS clause: $source", ctx)
+ }
+ }
+
+ /**
+ * Create a [[RowFormat]] used for creating tables.
+ *
+ * Example format:
+ * {{{
+ * SERDE serde_name [WITH SERDEPROPERTIES (k1=v1, k2=v2, ...)]
+ * }}}
+ *
+ * OR
+ *
+ * {{{
+ * DELIMITED [FIELDS TERMINATED BY char [ESCAPED BY char]]
+ * [COLLECTION ITEMS TERMINATED BY char]
+ * [MAP KEYS TERMINATED BY char]
+ * [LINES TERMINATED BY char]
+ * [NULL DEFINED AS char]
+ * }}}
+ */
+ private def visitRowFormat(ctx: RowFormatContext): CatalogStorageFormat = withOrigin(ctx) {
+ ctx match {
+ case serde: RowFormatSerdeContext => visitRowFormatSerde(serde)
+ case delimited: RowFormatDelimitedContext => visitRowFormatDelimited(delimited)
+ }
+ }
+
+ /**
+ * Create SERDE row format name and properties pair.
+ */
+ override def visitRowFormatSerde(
+ ctx: RowFormatSerdeContext): CatalogStorageFormat = withOrigin(ctx) {
+ import ctx._
+ EmptyStorageFormat.copy(
+ serde = Option(string(name)),
+ serdeProperties = Option(tablePropertyList).map(visitTablePropertyList).getOrElse(Map.empty))
+ }
+
+ /**
+ * Create a delimited row format properties object.
+ */
+ override def visitRowFormatDelimited(
+ ctx: RowFormatDelimitedContext): CatalogStorageFormat = withOrigin(ctx) {
+ // Collect the entries if any.
+ def entry(key: String, value: Token): Seq[(String, String)] = {
+ Option(value).toSeq.map(x => key -> string(x))
+ }
+ // TODO we need proper support for the NULL format.
+ val entries = entry(serdeConstants.FIELD_DELIM, ctx.fieldsTerminatedBy) ++
+ entry(serdeConstants.SERIALIZATION_FORMAT, ctx.fieldsTerminatedBy) ++
+ entry(serdeConstants.ESCAPE_CHAR, ctx.escapedBy) ++
+ entry(serdeConstants.COLLECTION_DELIM, ctx.collectionItemsTerminatedBy) ++
+ entry(serdeConstants.MAPKEY_DELIM, ctx.keysTerminatedBy) ++
+ Option(ctx.linesSeparatedBy).toSeq.map { token =>
+ val value = string(token)
+ assert(
+ value == "\n",
+ s"LINES TERMINATED BY only supports newline '\\n' right now: $value",
+ ctx)
+ serdeConstants.LINE_DELIM -> value
+ }
+ EmptyStorageFormat.copy(serdeProperties = entries.toMap)
+ }
+
+ /**
+ * Create a sequence of [[CatalogColumn]]s from a column list
+ */
+ private def visitCatalogColumns(ctx: ColTypeListContext): Seq[CatalogColumn] = withOrigin(ctx) {
+ ctx.colType.asScala.map { col =>
+ CatalogColumn(
+ col.identifier.getText.toLowerCase,
+ // Note: for types like "STRUCT<myFirstName: STRING, myLastName: STRING>" we can't
+ // just convert the whole type string to lower case, otherwise the struct field names
+ // will no longer be case sensitive. Instead, we rely on our parser to get the proper
+ // case before passing it to Hive.
+ CatalystSqlParser.parseDataType(col.dataType.getText).simpleString,
+ nullable = true,
+ Option(col.STRING).map(string))
+ }
+ }
+}
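A minimal usage sketch of the new parser (not part of this commit; the DDL string, database name and table name are made up for illustration), mirroring how HiveDDLCommandSuite further down in this diff drives it:

  import org.apache.spark.sql.catalyst.catalog.CatalogTable
  import org.apache.spark.sql.execution.command.CreateTable
  import org.apache.spark.sql.hive.execution.HiveSqlParser

  // Parse a Hive DDL statement into a LogicalPlan, then pull out the CatalogTable
  // that visitCreateTable built (collect walks the plan tree, as in the test suite).
  val ddl = "CREATE EXTERNAL TABLE mydb.page_view (viewTime INT) LOCATION '/user/external/page_view'"
  val desc: CatalogTable = HiveSqlParser.parsePlan(ddl).collect {
    case CreateTable(tableDesc, _) => tableDesc
  }.head
  // desc.identifier.database == Some("mydb"), desc.tableType == CatalogTableType.EXTERNAL_TABLE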
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala
index cd26a68f35..06badff474 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala
@@ -24,7 +24,6 @@ import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
-import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.execution.command.RunnableCommand
import org.apache.spark.sql.execution.datasources.{BucketSpec, DataSource, LogicalRelation}
import org.apache.spark.sql.hive.HiveContext
@@ -47,36 +46,6 @@ case class AnalyzeTable(tableName: String) extends RunnableCommand {
}
}
-/**
- * Drops a table from the metastore and removes it if it is cached.
- */
-private[hive]
-case class DropTable(
- tableName: String,
- ifExists: Boolean) extends RunnableCommand {
-
- override def run(sqlContext: SQLContext): Seq[Row] = {
- val hiveContext = sqlContext.asInstanceOf[HiveContext]
- val ifExistsClause = if (ifExists) "IF EXISTS " else ""
- try {
- hiveContext.cacheManager.tryUncacheQuery(hiveContext.table(tableName))
- } catch {
- // This table's metadata is not in Hive metastore (e.g. the table does not exist).
- case _: org.apache.hadoop.hive.ql.metadata.InvalidTableException =>
- case _: org.apache.spark.sql.catalyst.analysis.NoSuchTableException =>
- // Other Throwables can be caused by users providing wrong parameters in OPTIONS
- // (e.g. invalid paths). We catch it and log a warning message.
- // Users should be able to drop such kinds of tables regardless if there is an error.
- case e: Throwable => log.warn(s"${e.getMessage}", e)
- }
- hiveContext.invalidateTable(tableName)
- hiveContext.runSqlHive(s"DROP TABLE $ifExistsClause$tableName")
- hiveContext.sessionState.catalog.dropTable(
- TableIdentifier(tableName), ignoreIfNotExists = true)
- Seq.empty[Row]
- }
-}
-
private[hive]
case class AddJar(path: String) extends RunnableCommand {
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
index efaa052370..784b018353 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
@@ -19,7 +19,6 @@ package org.apache.spark.sql.hive
import scala.collection.JavaConverters._
import scala.collection.mutable.ArrayBuffer
-import scala.util.Try
import org.apache.hadoop.hive.ql.exec._
import org.apache.hadoop.hive.ql.udf.{UDFType => HiveUDFType}
@@ -31,118 +30,14 @@ import org.apache.hadoop.hive.serde2.objectinspector.{ConstantObjectInspector, O
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory.ObjectInspectorOptions
import org.apache.spark.internal.Logging
-import org.apache.spark.sql.AnalysisException
-import org.apache.spark.sql.catalyst.{analysis, InternalRow}
-import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder
+import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
import org.apache.spark.sql.hive.HiveShim._
-import org.apache.spark.sql.hive.client.HiveClientImpl
import org.apache.spark.sql.types._
-private[hive] class HiveFunctionRegistry(
- underlying: analysis.FunctionRegistry,
- executionHive: HiveClientImpl)
- extends analysis.FunctionRegistry with HiveInspectors {
-
- def getFunctionInfo(name: String): FunctionInfo = {
- // Hive Registry need current database to lookup function
- // TODO: the current database of executionHive should be consistent with metadataHive
- executionHive.withHiveState {
- FunctionRegistry.getFunctionInfo(name)
- }
- }
-
- override def lookupFunction(name: String, children: Seq[Expression]): Expression = {
- Try(underlying.lookupFunction(name, children)).getOrElse {
- // We only look it up to see if it exists, but do not include it in the HiveUDF since it is
- // not always serializable.
- val functionInfo: FunctionInfo =
- Option(getFunctionInfo(name.toLowerCase)).getOrElse(
- throw new AnalysisException(s"undefined function $name"))
-
- val functionClassName = functionInfo.getFunctionClass.getName
-
- // When we instantiate hive UDF wrapper class, we may throw exception if the input expressions
- // don't satisfy the hive UDF, such as type mismatch, input number mismatch, etc. Here we
- // catch the exception and throw AnalysisException instead.
- try {
- if (classOf[GenericUDFMacro].isAssignableFrom(functionInfo.getFunctionClass)) {
- val udf = HiveGenericUDF(
- name, new HiveFunctionWrapper(functionClassName, functionInfo.getGenericUDF), children)
- udf.dataType // Force it to check input data types.
- udf
- } else if (classOf[UDF].isAssignableFrom(functionInfo.getFunctionClass)) {
- val udf = HiveSimpleUDF(name, new HiveFunctionWrapper(functionClassName), children)
- udf.dataType // Force it to check input data types.
- udf
- } else if (classOf[GenericUDF].isAssignableFrom(functionInfo.getFunctionClass)) {
- val udf = HiveGenericUDF(name, new HiveFunctionWrapper(functionClassName), children)
- udf.dataType // Force it to check input data types.
- udf
- } else if (
- classOf[AbstractGenericUDAFResolver].isAssignableFrom(functionInfo.getFunctionClass)) {
- val udaf = HiveUDAFFunction(name, new HiveFunctionWrapper(functionClassName), children)
- udaf.dataType // Force it to check input data types.
- udaf
- } else if (classOf[UDAF].isAssignableFrom(functionInfo.getFunctionClass)) {
- val udaf = HiveUDAFFunction(
- name, new HiveFunctionWrapper(functionClassName), children, isUDAFBridgeRequired = true)
- udaf.dataType // Force it to check input data types.
- udaf
- } else if (classOf[GenericUDTF].isAssignableFrom(functionInfo.getFunctionClass)) {
- val udtf = HiveGenericUDTF(name, new HiveFunctionWrapper(functionClassName), children)
- udtf.elementTypes // Force it to check input data types.
- udtf
- } else {
- throw new AnalysisException(s"No handler for udf ${functionInfo.getFunctionClass}")
- }
- } catch {
- case analysisException: AnalysisException =>
- // If the exception is an AnalysisException, just throw it.
- throw analysisException
- case throwable: Throwable =>
- // If there is any other error, we throw an AnalysisException.
- val errorMessage = s"No handler for Hive udf ${functionInfo.getFunctionClass} " +
- s"because: ${throwable.getMessage}."
- throw new AnalysisException(errorMessage)
- }
- }
- }
-
- override def registerFunction(name: String, info: ExpressionInfo, builder: FunctionBuilder)
- : Unit = underlying.registerFunction(name, info, builder)
-
- /* List all of the registered function names. */
- override def listFunction(): Seq[String] = {
- (FunctionRegistry.getFunctionNames.asScala ++ underlying.listFunction()).toList.sorted
- }
-
- /* Get the class of the registered function by specified name. */
- override def lookupFunction(name: String): Option[ExpressionInfo] = {
- underlying.lookupFunction(name).orElse(
- Try {
- val info = getFunctionInfo(name)
- val annotation = info.getFunctionClass.getAnnotation(classOf[Description])
- if (annotation != null) {
- Some(new ExpressionInfo(
- info.getFunctionClass.getCanonicalName,
- annotation.name(),
- annotation.value(),
- annotation.extended()))
- } else {
- Some(new ExpressionInfo(
- info.getFunctionClass.getCanonicalName,
- name,
- null,
- null))
- }
- }.getOrElse(None))
- }
-}
-
private[hive] case class HiveSimpleUDF(
name: String, funcWrapper: HiveFunctionWrapper, children: Seq[Expression])
extends Expression with HiveInspectors with CodegenFallback with Logging {
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala
index 7c4a0a0c0f..21591ec093 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala
@@ -33,7 +33,6 @@ import org.apache.hadoop.mapreduce._
import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat, FileSplit}
import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl
-import org.apache.spark.broadcast.Broadcast
import org.apache.spark.internal.Logging
import org.apache.spark.rdd.{HadoopRDD, RDD}
import org.apache.spark.sql.{Row, SQLContext}
@@ -45,7 +44,6 @@ import org.apache.spark.sql.hive.{HiveInspectors, HiveMetastoreTypes, HiveShim}
import org.apache.spark.sql.sources.{Filter, _}
import org.apache.spark.sql.types.StructType
import org.apache.spark.util.SerializableConfiguration
-import org.apache.spark.util.collection.BitSet
private[sql] class DefaultSource
extends FileFormat with DataSourceRegister with Serializable {
@@ -111,23 +109,11 @@ private[sql] class DefaultSource
}
}
- override def buildInternalScan(
- sqlContext: SQLContext,
- dataSchema: StructType,
- requiredColumns: Array[String],
- filters: Array[Filter],
- bucketSet: Option[BitSet],
- inputFiles: Seq[FileStatus],
- broadcastedConf: Broadcast[SerializableConfiguration],
- options: Map[String, String]): RDD[InternalRow] = {
- val output = StructType(requiredColumns.map(dataSchema(_))).toAttributes
- OrcTableScan(sqlContext, output, filters, inputFiles).execute()
- }
-
override def buildReader(
sqlContext: SQLContext,
- partitionSchema: StructType,
dataSchema: StructType,
+ partitionSchema: StructType,
+ requiredSchema: StructType,
filters: Seq[Filter],
options: Map[String, String]): (PartitionedFile) => Iterator[InternalRow] = {
val orcConf = new Configuration(sqlContext.sparkContext.hadoopConfiguration)
@@ -145,15 +131,15 @@ private[sql] class DefaultSource
(file: PartitionedFile) => {
val conf = broadcastedConf.value.value
- // SPARK-8501: Empty ORC files always have an empty schema stored in their footer. In this
- // case, `OrcFileOperator.readSchema` returns `None`, and we can simply return an empty
- // iterator.
+ // SPARK-8501: Empty ORC files always have an empty schema stored in their footer. In this
+ // case, `OrcFileOperator.readSchema` returns `None`, and we can't read the underlying file
+ // using the given physical schema. Instead, we simply return an empty iterator.
val maybePhysicalSchema = OrcFileOperator.readSchema(Seq(file.filePath), Some(conf))
if (maybePhysicalSchema.isEmpty) {
Iterator.empty
} else {
val physicalSchema = maybePhysicalSchema.get
- OrcRelation.setRequiredColumns(conf, physicalSchema, dataSchema)
+ OrcRelation.setRequiredColumns(conf, physicalSchema, requiredSchema)
val orcRecordReader = {
val job = Job.getInstance(conf)
@@ -171,11 +157,11 @@ private[sql] class DefaultSource
// Unwraps `OrcStruct`s to `UnsafeRow`s
val unsafeRowIterator = OrcRelation.unwrapOrcStructs(
- file.filePath, conf, dataSchema, new RecordReaderIterator[OrcStruct](orcRecordReader)
+ file.filePath, conf, requiredSchema, new RecordReaderIterator[OrcStruct](orcRecordReader)
)
// Appends partition values
- val fullOutput = dataSchema.toAttributes ++ partitionSchema.toAttributes
+ val fullOutput = requiredSchema.toAttributes ++ partitionSchema.toAttributes
val joinedRow = new JoinedRow()
val appendPartitionColumns = GenerateUnsafeProjection.generate(fullOutput, fullOutput)
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala
index a1785ca038..7f6ca21782 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala
@@ -78,7 +78,7 @@ class TestHiveContext private[hive](
executionHive: HiveClientImpl,
metadataHive: HiveClient,
isRootContext: Boolean,
- hiveCatalog: HiveCatalog,
+ hiveCatalog: HiveExternalCatalog,
val warehousePath: File,
val scratchDirPath: File,
metastoreTemporaryConf: Map[String, String])
@@ -110,7 +110,7 @@ class TestHiveContext private[hive](
executionHive,
metadataHive,
true,
- new HiveCatalog(metadataHive),
+ new HiveExternalCatalog(metadataHive),
warehousePath,
scratchDirPath,
metastoreTemporaryConf)
@@ -201,8 +201,13 @@ class TestHiveContext private[hive](
}
override lazy val functionRegistry = {
- new TestHiveFunctionRegistry(
- org.apache.spark.sql.catalyst.analysis.FunctionRegistry.builtin.copy(), self.executionHive)
+ // We use TestHiveFunctionRegistry here to track functions that have been explicitly
+ // unregistered (through the TestHiveFunctionRegistry.unregisterFunction method).
+ val fr = new TestHiveFunctionRegistry
+ org.apache.spark.sql.catalyst.analysis.FunctionRegistry.expressions.foreach {
+ case (name, (info, builder)) => fr.registerFunction(name, info, builder)
+ }
+ fr
}
}
@@ -380,8 +385,8 @@ class TestHiveContext private[hive](
""".stripMargin.cmd,
s"LOAD DATA LOCAL INPATH '${getHiveFile("data/files/episodes.avro")}' INTO TABLE episodes".cmd
),
- // THIS TABLE IS NOT THE SAME AS THE HIVE TEST TABLE episodes_partitioned AS DYNAMIC PARITIONING
- // IS NOT YET SUPPORTED
+ // THIS TABLE IS NOT THE SAME AS THE HIVE TEST TABLE episodes_partitioned AS DYNAMIC
+ // PARTITIONING IS NOT YET SUPPORTED
TestTable("episodes_part",
s"""CREATE TABLE episodes_part (title STRING, air_date STRING, doctor INT)
|PARTITIONED BY (doctor_pt INT)
@@ -528,19 +533,18 @@ class TestHiveContext private[hive](
}
-private[hive] class TestHiveFunctionRegistry(fr: SimpleFunctionRegistry, client: HiveClientImpl)
- extends HiveFunctionRegistry(fr, client) {
+private[hive] class TestHiveFunctionRegistry extends SimpleFunctionRegistry {
private val removedFunctions =
collection.mutable.ArrayBuffer.empty[(String, (ExpressionInfo, FunctionBuilder))]
def unregisterFunction(name: String): Unit = {
- fr.functionBuilders.remove(name).foreach(f => removedFunctions += name -> f)
+ functionBuilders.remove(name).foreach(f => removedFunctions += name -> f)
}
def restore(): Unit = {
removedFunctions.foreach {
- case (name, (info, builder)) => fr.registerFunction(name, info, builder)
+ case (name, (info, builder)) => registerFunction(name, info, builder)
}
}
}
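A minimal sketch (not part of this commit; the function name "abs" is an arbitrary example) of how the reworked registry could be used to temporarily hide a builtin during a test and put it back afterwards:

  import org.apache.spark.sql.catalyst.analysis.FunctionRegistry

  // Build a registry pre-populated with the builtins, as the new functionRegistry above does.
  val fr = new TestHiveFunctionRegistry
  FunctionRegistry.expressions.foreach { case (name, (info, builder)) =>
    fr.registerFunction(name, info, builder)
  }
  fr.unregisterFunction("abs") // hide a builtin for the duration of a test
  // ... exercise the behaviour under test ...
  fr.restore()                 // re-register everything removed above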
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ClasspathDependenciesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ClasspathDependenciesSuite.scala
index 34b2edb44b..f262ef62be 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ClasspathDependenciesSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ClasspathDependenciesSuite.scala
@@ -24,9 +24,7 @@ import org.apache.spark.SparkFunSuite
/**
* Verify that some classes load and that others are not found on the classpath.
*
- *
- * This is used to detect classpath and shading conflict, especially between
- * Spark's required Kryo version and that which can be found in some Hive versions.
+ * This is used to detect classpath and shading conflicts.
*/
class ClasspathDependenciesSuite extends SparkFunSuite {
private val classloader = this.getClass.getClassLoader
@@ -40,10 +38,6 @@ class ClasspathDependenciesSuite extends SparkFunSuite {
classloader.loadClass(classname)
}
- private def assertLoads(classes: String*): Unit = {
- classes.foreach(assertLoads)
- }
-
private def findResource(classname: String): URL = {
val resource = resourceName(classname)
classloader.getResource(resource)
@@ -63,17 +57,12 @@ class ClasspathDependenciesSuite extends SparkFunSuite {
}
}
- private def assertClassNotFound(classes: String*): Unit = {
- classes.foreach(assertClassNotFound)
+ test("shaded Protobuf") {
+ assertLoads("org.apache.hive.com.google.protobuf.ServiceException")
}
- private val KRYO = "com.esotericsoftware.kryo.Kryo"
-
- private val SPARK_HIVE = "org.apache.hive."
- private val SPARK_SHADED = "org.spark-project.hive.shaded."
-
- test("shaded Protobuf") {
- assertLoads(SPARK_SHADED + "com.google.protobuf.ServiceException")
+ test("shaded Kryo") {
+ assertLoads("org.apache.hive.com.esotericsoftware.kryo.Kryo")
}
test("hive-common") {
@@ -86,25 +75,13 @@ class ClasspathDependenciesSuite extends SparkFunSuite {
private val STD_INSTANTIATOR = "org.objenesis.strategy.StdInstantiatorStrategy"
- test("unshaded kryo") {
- assertLoads(KRYO, STD_INSTANTIATOR)
- }
-
test("Forbidden Dependencies") {
- assertClassNotFound(
- SPARK_HIVE + KRYO,
- SPARK_SHADED + KRYO,
- "org.apache.hive." + KRYO,
- "com.esotericsoftware.shaded." + STD_INSTANTIATOR,
- SPARK_HIVE + "com.esotericsoftware.shaded." + STD_INSTANTIATOR,
- "org.apache.hive.com.esotericsoftware.shaded." + STD_INSTANTIATOR
- )
+ assertClassNotFound("com.esotericsoftware.shaded." + STD_INSTANTIATOR)
+ assertClassNotFound("org.apache.hive.com.esotericsoftware.shaded." + STD_INSTANTIATOR)
}
test("parquet-hadoop-bundle") {
- assertLoads(
- "parquet.hadoop.ParquetOutputFormat",
- "parquet.hadoop.ParquetInputFormat"
- )
+ assertLoads("parquet.hadoop.ParquetOutputFormat")
+ assertLoads("parquet.hadoop.ParquetInputFormat")
}
}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala
index 4b6da7cd33..d9664680f4 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala
@@ -22,8 +22,8 @@ import scala.util.Try
import org.scalatest.BeforeAndAfterEach
import org.apache.spark.sql.{AnalysisException, QueryTest}
-import org.apache.spark.sql.catalyst.parser.ParseDriver
import org.apache.spark.sql.catalyst.util.quietly
+import org.apache.spark.sql.hive.execution.HiveSqlParser
import org.apache.spark.sql.hive.test.TestHiveSingleton
class ErrorPositionSuite extends QueryTest with TestHiveSingleton with BeforeAndAfterEach {
@@ -131,7 +131,7 @@ class ErrorPositionSuite extends QueryTest with TestHiveSingleton with BeforeAnd
* @param token a unique token in the string that should be indicated by the exception
*/
def positionTest(name: String, query: String, token: String): Unit = {
- def ast = ParseDriver.parsePlan(query, hiveContext.conf)
+ def ast = HiveSqlParser.parsePlan(query)
def parseTree = Try(quietly(ast.treeString)).getOrElse("<failed to parse>")
test(name) {
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ExpressionToSQLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ExpressionToSQLSuite.scala
index 75930086ff..bf85d71c66 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ExpressionToSQLSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ExpressionToSQLSuite.scala
@@ -213,8 +213,8 @@ class ExpressionToSQLSuite extends SQLBuilderTest with SQLTestUtils {
checkSqlGeneration("SELECT space(2)")
checkSqlGeneration("SELECT split('aa2bb3cc', '[1-9]+')")
checkSqlGeneration("SELECT space(2)")
- checkSqlGeneration("SELECT substr('This is a test', 'is')")
- checkSqlGeneration("SELECT substring('This is a test', 'is')")
+ checkSqlGeneration("SELECT substr('This is a test', 1)")
+ checkSqlGeneration("SELECT substring('This is a test', 1)")
checkSqlGeneration("SELECT substring_index('www.apache.org','.',1)")
checkSqlGeneration("SELECT translate('translate', 'rnlt', '123')")
checkSqlGeneration("SELECT trim(' SparkSql ')")
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDDLCommandSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDDLCommandSuite.scala
new file mode 100644
index 0000000000..110c6d19d8
--- /dev/null
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDDLCommandSuite.scala
@@ -0,0 +1,582 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive
+
+import org.apache.hadoop.hive.serde.serdeConstants
+
+import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
+import org.apache.spark.sql.catalyst.catalog.{CatalogColumn, CatalogTable, CatalogTableType}
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.dsl.plans
+import org.apache.spark.sql.catalyst.dsl.plans.DslLogicalPlan
+import org.apache.spark.sql.catalyst.expressions.JsonTuple
+import org.apache.spark.sql.catalyst.parser.ParseException
+import org.apache.spark.sql.catalyst.plans.PlanTest
+import org.apache.spark.sql.catalyst.plans.logical.{Generate, ScriptTransformation}
+import org.apache.spark.sql.execution.command.{CreateTable, CreateTableLike}
+import org.apache.spark.sql.hive.execution.{HiveNativeCommand, HiveSqlParser}
+
+class HiveDDLCommandSuite extends PlanTest {
+ val parser = HiveSqlParser
+
+ private def extractTableDesc(sql: String): (CatalogTable, Boolean) = {
+ parser.parsePlan(sql).collect {
+ case CreateTable(desc, allowExisting) => (desc, allowExisting)
+ case CreateTableAsSelect(desc, _, allowExisting) => (desc, allowExisting)
+ case CreateViewAsSelect(desc, _, allowExisting, _, _) => (desc, allowExisting)
+ }.head
+ }
+
+ private def assertUnsupported(sql: String): Unit = {
+ val e = intercept[ParseException] {
+ parser.parsePlan(sql)
+ }
+ assert(e.getMessage.toLowerCase.contains("unsupported"))
+ }
+
+ test("Test CTAS #1") {
+ val s1 =
+ """CREATE EXTERNAL TABLE IF NOT EXISTS mydb.page_view
+ |(viewTime INT,
+ |userid BIGINT,
+ |page_url STRING,
+ |referrer_url STRING,
+ |ip STRING COMMENT 'IP Address of the User',
+ |country STRING COMMENT 'country of origination')
+ |COMMENT 'This is the staging page view table'
+ |PARTITIONED BY (dt STRING COMMENT 'date type', hour STRING COMMENT 'hour of the day')
+ |ROW FORMAT DELIMITED FIELDS TERMINATED BY '\054' STORED AS RCFILE
+ |LOCATION '/user/external/page_view'
+ |TBLPROPERTIES ('p1'='v1', 'p2'='v2')
+ |AS SELECT * FROM src""".stripMargin
+
+ val (desc, exists) = extractTableDesc(s1)
+ assert(exists)
+ assert(desc.identifier.database == Some("mydb"))
+ assert(desc.identifier.table == "page_view")
+ assert(desc.tableType == CatalogTableType.EXTERNAL_TABLE)
+ assert(desc.storage.locationUri == Some("/user/external/page_view"))
+ assert(desc.schema ==
+ CatalogColumn("viewtime", "int") ::
+ CatalogColumn("userid", "bigint") ::
+ CatalogColumn("page_url", "string") ::
+ CatalogColumn("referrer_url", "string") ::
+ CatalogColumn("ip", "string", comment = Some("IP Address of the User")) ::
+ CatalogColumn("country", "string", comment = Some("country of origination")) ::
+ CatalogColumn("dt", "string", comment = Some("date type")) ::
+ CatalogColumn("hour", "string", comment = Some("hour of the day")) :: Nil)
+ assert(desc.comment == Some("This is the staging page view table"))
+ // TODO will be SQLText
+ assert(desc.viewText.isEmpty)
+ assert(desc.viewOriginalText.isEmpty)
+ assert(desc.partitionColumns ==
+ CatalogColumn("dt", "string", comment = Some("date type")) ::
+ CatalogColumn("hour", "string", comment = Some("hour of the day")) :: Nil)
+ assert(desc.storage.serdeProperties ==
+ Map((serdeConstants.SERIALIZATION_FORMAT, "\u002C"), (serdeConstants.FIELD_DELIM, "\u002C")))
+ assert(desc.storage.inputFormat == Some("org.apache.hadoop.hive.ql.io.RCFileInputFormat"))
+ assert(desc.storage.outputFormat == Some("org.apache.hadoop.hive.ql.io.RCFileOutputFormat"))
+ assert(desc.storage.serde ==
+ Some("org.apache.hadoop.hive.serde2.columnar.LazyBinaryColumnarSerDe"))
+ assert(desc.properties == Map(("p1", "v1"), ("p2", "v2")))
+ }
+
+ test("Test CTAS #2") {
+ val s2 =
+ """CREATE EXTERNAL TABLE IF NOT EXISTS mydb.page_view
+ |(viewTime INT,
+ |userid BIGINT,
+ |page_url STRING,
+ |referrer_url STRING,
+ |ip STRING COMMENT 'IP Address of the User',
+ |country STRING COMMENT 'country of origination')
+ |COMMENT 'This is the staging page view table'
+ |PARTITIONED BY (dt STRING COMMENT 'date type', hour STRING COMMENT 'hour of the day')
+ |ROW FORMAT SERDE 'parquet.hive.serde.ParquetHiveSerDe'
+ | STORED AS
+ | INPUTFORMAT 'parquet.hive.DeprecatedParquetInputFormat'
+ | OUTPUTFORMAT 'parquet.hive.DeprecatedParquetOutputFormat'
+ |LOCATION '/user/external/page_view'
+ |TBLPROPERTIES ('p1'='v1', 'p2'='v2')
+ |AS SELECT * FROM src""".stripMargin
+
+ val (desc, exists) = extractTableDesc(s2)
+ assert(exists)
+ assert(desc.identifier.database == Some("mydb"))
+ assert(desc.identifier.table == "page_view")
+ assert(desc.tableType == CatalogTableType.EXTERNAL_TABLE)
+ assert(desc.storage.locationUri == Some("/user/external/page_view"))
+ assert(desc.schema ==
+ CatalogColumn("viewtime", "int") ::
+ CatalogColumn("userid", "bigint") ::
+ CatalogColumn("page_url", "string") ::
+ CatalogColumn("referrer_url", "string") ::
+ CatalogColumn("ip", "string", comment = Some("IP Address of the User")) ::
+ CatalogColumn("country", "string", comment = Some("country of origination")) ::
+ CatalogColumn("dt", "string", comment = Some("date type")) ::
+ CatalogColumn("hour", "string", comment = Some("hour of the day")) :: Nil)
+ // TODO will be SQLText
+ assert(desc.comment == Some("This is the staging page view table"))
+ assert(desc.viewText.isEmpty)
+ assert(desc.viewOriginalText.isEmpty)
+ assert(desc.partitionColumns ==
+ CatalogColumn("dt", "string", comment = Some("date type")) ::
+ CatalogColumn("hour", "string", comment = Some("hour of the day")) :: Nil)
+ assert(desc.storage.serdeProperties == Map())
+ assert(desc.storage.inputFormat == Some("parquet.hive.DeprecatedParquetInputFormat"))
+ assert(desc.storage.outputFormat == Some("parquet.hive.DeprecatedParquetOutputFormat"))
+ assert(desc.storage.serde == Some("parquet.hive.serde.ParquetHiveSerDe"))
+ assert(desc.properties == Map(("p1", "v1"), ("p2", "v2")))
+ }
+
+ test("Test CTAS #3") {
+ val s3 = """CREATE TABLE page_view AS SELECT * FROM src"""
+ val (desc, exists) = extractTableDesc(s3)
+ assert(exists == false)
+ assert(desc.identifier.database == None)
+ assert(desc.identifier.table == "page_view")
+ assert(desc.tableType == CatalogTableType.MANAGED_TABLE)
+ assert(desc.storage.locationUri == None)
+ assert(desc.schema == Seq.empty[CatalogColumn])
+ assert(desc.viewText == None) // TODO will be SQLText
+ assert(desc.viewOriginalText.isEmpty)
+ assert(desc.storage.serdeProperties == Map())
+ assert(desc.storage.inputFormat == Some("org.apache.hadoop.mapred.TextInputFormat"))
+ assert(desc.storage.outputFormat ==
+ Some("org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat"))
+ assert(desc.storage.serde.isEmpty)
+ assert(desc.properties == Map())
+ }
+
+ test("Test CTAS #4") {
+ val s4 =
+ """CREATE TABLE page_view
+ |STORED BY 'storage.handler.class.name' AS SELECT * FROM src""".stripMargin
+ intercept[AnalysisException] {
+ extractTableDesc(s4)
+ }
+ }
+
+ test("Test CTAS #5") {
+ val s5 = """CREATE TABLE ctas2
+ | ROW FORMAT SERDE "org.apache.hadoop.hive.serde2.columnar.ColumnarSerDe"
+ | WITH SERDEPROPERTIES("serde_p1"="p1","serde_p2"="p2")
+ | STORED AS RCFile
+ | TBLPROPERTIES("tbl_p1"="p11", "tbl_p2"="p22")
+ | AS
+ | SELECT key, value
+ | FROM src
+ | ORDER BY key, value""".stripMargin
+ val (desc, exists) = extractTableDesc(s5)
+ assert(exists == false)
+ assert(desc.identifier.database == None)
+ assert(desc.identifier.table == "ctas2")
+ assert(desc.tableType == CatalogTableType.MANAGED_TABLE)
+ assert(desc.storage.locationUri == None)
+ assert(desc.schema == Seq.empty[CatalogColumn])
+ assert(desc.viewText == None) // TODO will be SQLText
+ assert(desc.viewOriginalText.isEmpty)
+ assert(desc.storage.serdeProperties == Map(("serde_p1" -> "p1"), ("serde_p2" -> "p2")))
+ assert(desc.storage.inputFormat == Some("org.apache.hadoop.hive.ql.io.RCFileInputFormat"))
+ assert(desc.storage.outputFormat == Some("org.apache.hadoop.hive.ql.io.RCFileOutputFormat"))
+ assert(desc.storage.serde == Some("org.apache.hadoop.hive.serde2.columnar.ColumnarSerDe"))
+ assert(desc.properties == Map(("tbl_p1" -> "p11"), ("tbl_p2" -> "p22")))
+ }
+
+ test("unsupported operations") {
+ intercept[ParseException] {
+ parser.parsePlan(
+ """
+ |CREATE TEMPORARY TABLE ctas2
+ |ROW FORMAT SERDE "org.apache.hadoop.hive.serde2.columnar.ColumnarSerDe"
+ |WITH SERDEPROPERTIES("serde_p1"="p1","serde_p2"="p2")
+ |STORED AS RCFile
+ |TBLPROPERTIES("tbl_p1"="p11", "tbl_p2"="p22")
+ |AS SELECT key, value FROM src ORDER BY key, value
+ """.stripMargin)
+ }
+ intercept[ParseException] {
+ parser.parsePlan(
+ """
+ |CREATE TABLE user_info_bucketed(user_id BIGINT, firstname STRING, lastname STRING)
+ |CLUSTERED BY(user_id) INTO 256 BUCKETS
+ |AS SELECT key, value FROM src ORDER BY key, value
+ """.stripMargin)
+ }
+ intercept[ParseException] {
+ parser.parsePlan(
+ """
+ |CREATE TABLE user_info_bucketed(user_id BIGINT, firstname STRING, lastname STRING)
+ |SKEWED BY (key) ON (1,5,6)
+ |AS SELECT key, value FROM src ORDER BY key, value
+ """.stripMargin)
+ }
+ intercept[ParseException] {
+ parser.parsePlan(
+ """
+ |SELECT TRANSFORM (key, value) USING 'cat' AS (tKey, tValue)
+ |ROW FORMAT SERDE 'org.apache.hadoop.hive.contrib.serde2.TypedBytesSerDe'
+ |RECORDREADER 'org.apache.hadoop.hive.contrib.util.typedbytes.TypedBytesRecordReader'
+ |FROM testData
+ """.stripMargin)
+ }
+ }
+
+ test("Invalid interval term should throw AnalysisException") {
+ def assertError(sql: String, errorMessage: String): Unit = {
+ val e = intercept[AnalysisException] {
+ parser.parsePlan(sql)
+ }
+ assert(e.getMessage.contains(errorMessage))
+ }
+ assertError("select interval '42-32' year to month",
+ "month 32 outside range [0, 11]")
+ assertError("select interval '5 49:12:15' day to second",
+ "hour 49 outside range [0, 23]")
+ assertError("select interval '.1111111111' second",
+ "nanosecond 1111111111 outside range")
+ }
+
+ test("use native json_tuple instead of hive's UDTF in LATERAL VIEW") {
+ val plan = parser.parsePlan(
+ """
+ |SELECT *
+ |FROM (SELECT '{"f1": "value1", "f2": 12}' json) test
+ |LATERAL VIEW json_tuple(json, 'f1', 'f2') jt AS a, b
+ """.stripMargin)
+
+ assert(plan.children.head.asInstanceOf[Generate].generator.isInstanceOf[JsonTuple])
+ }
+
+ test("transform query spec") {
+ val plan1 = parser.parsePlan("select transform(a, b) using 'func' from e where f < 10")
+ .asInstanceOf[ScriptTransformation].copy(ioschema = null)
+ val plan2 = parser.parsePlan("map a, b using 'func' as c, d from e")
+ .asInstanceOf[ScriptTransformation].copy(ioschema = null)
+ val plan3 = parser.parsePlan("reduce a, b using 'func' as (c: int, d decimal(10, 0)) from e")
+ .asInstanceOf[ScriptTransformation].copy(ioschema = null)
+
+ val p = ScriptTransformation(
+ Seq(UnresolvedAttribute("a"), UnresolvedAttribute("b")),
+ "func", Seq.empty, plans.table("e"), null)
+
+ comparePlans(plan1,
+ p.copy(child = p.child.where('f < 10), output = Seq('key.string, 'value.string)))
+ comparePlans(plan2,
+ p.copy(output = Seq('c.string, 'd.string)))
+ comparePlans(plan3,
+ p.copy(output = Seq('c.int, 'd.decimal(10, 0))))
+ }
+
+ test("use backticks in output of Script Transform") {
+ parser.parsePlan(
+ """SELECT `t`.`thing1`
+ |FROM (SELECT TRANSFORM (`parquet_t1`.`key`, `parquet_t1`.`value`)
+ |USING 'cat' AS (`thing1` int, `thing2` string) FROM `default`.`parquet_t1`) AS t
+ """.stripMargin)
+ }
+
+ test("use backticks in output of Generator") {
+ parser.parsePlan(
+ """
+ |SELECT `gentab2`.`gencol2`
+ |FROM `default`.`src`
+ |LATERAL VIEW explode(array(array(1, 2, 3))) `gentab1` AS `gencol1`
+ |LATERAL VIEW explode(`gentab1`.`gencol1`) `gentab2` AS `gencol2`
+ """.stripMargin)
+ }
+
+ test("use escaped backticks in output of Generator") {
+ parser.parsePlan(
+ """
+ |SELECT `gen``tab2`.`gen``col2`
+ |FROM `default`.`src`
+ |LATERAL VIEW explode(array(array(1, 2, 3))) `gen``tab1` AS `gen``col1`
+ |LATERAL VIEW explode(`gen``tab1`.`gen``col1`) `gen``tab2` AS `gen``col2`
+ """.stripMargin)
+ }
+
+ test("create table - basic") {
+ val query = "CREATE TABLE my_table (id int, name string)"
+ val (desc, allowExisting) = extractTableDesc(query)
+ assert(!allowExisting)
+ assert(desc.identifier.database.isEmpty)
+ assert(desc.identifier.table == "my_table")
+ assert(desc.tableType == CatalogTableType.MANAGED_TABLE)
+ assert(desc.schema == Seq(CatalogColumn("id", "int"), CatalogColumn("name", "string")))
+ assert(desc.partitionColumnNames.isEmpty)
+ assert(desc.sortColumnNames.isEmpty)
+ assert(desc.bucketColumnNames.isEmpty)
+ assert(desc.numBuckets == -1)
+ assert(desc.viewText.isEmpty)
+ assert(desc.viewOriginalText.isEmpty)
+ assert(desc.storage.locationUri.isEmpty)
+ assert(desc.storage.inputFormat ==
+ Some("org.apache.hadoop.mapred.TextInputFormat"))
+ assert(desc.storage.outputFormat ==
+ Some("org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat"))
+ assert(desc.storage.serde.isEmpty)
+ assert(desc.storage.serdeProperties.isEmpty)
+ assert(desc.properties.isEmpty)
+ assert(desc.comment.isEmpty)
+ }
+
+ test("create table - with database name") {
+ val query = "CREATE TABLE dbx.my_table (id int, name string)"
+ val (desc, _) = extractTableDesc(query)
+ assert(desc.identifier.database == Some("dbx"))
+ assert(desc.identifier.table == "my_table")
+ }
+
+ test("create table - temporary") {
+ val query = "CREATE TEMPORARY TABLE tab1 (id int, name string)"
+ val e = intercept[ParseException] { parser.parsePlan(query) }
+ assert(e.message.contains("registerTempTable"))
+ }
+
+ test("create table - external") {
+ val query = "CREATE EXTERNAL TABLE tab1 (id int, name string)"
+ val (desc, _) = extractTableDesc(query)
+ assert(desc.tableType == CatalogTableType.EXTERNAL_TABLE)
+ }
+
+ test("create table - if not exists") {
+ val query = "CREATE TABLE IF NOT EXISTS tab1 (id int, name string)"
+ val (_, allowExisting) = extractTableDesc(query)
+ assert(allowExisting)
+ }
+
+ test("create table - comment") {
+ val query = "CREATE TABLE my_table (id int, name string) COMMENT 'its hot as hell below'"
+ val (desc, _) = extractTableDesc(query)
+ assert(desc.comment == Some("its hot as hell below"))
+ }
+
+ test("create table - partitioned columns") {
+ val query = "CREATE TABLE my_table (id int, name string) PARTITIONED BY (month int)"
+ val (desc, _) = extractTableDesc(query)
+ assert(desc.schema == Seq(
+ CatalogColumn("id", "int"),
+ CatalogColumn("name", "string"),
+ CatalogColumn("month", "int")))
+ assert(desc.partitionColumnNames == Seq("month"))
+ }
+
+ test("create table - clustered by") {
+ val baseQuery = "CREATE TABLE my_table (id int, name string) CLUSTERED BY(id)"
+ val query1 = s"$baseQuery INTO 10 BUCKETS"
+ val query2 = s"$baseQuery SORTED BY(id) INTO 10 BUCKETS"
+ val e1 = intercept[ParseException] { parser.parsePlan(query1) }
+ val e2 = intercept[ParseException] { parser.parsePlan(query2) }
+ assert(e1.getMessage.contains("Operation not allowed"))
+ assert(e2.getMessage.contains("Operation not allowed"))
+ }
+
+ test("create table - skewed by") {
+ val baseQuery = "CREATE TABLE my_table (id int, name string) SKEWED BY"
+ val query1 = s"$baseQuery(id) ON (1, 10, 100)"
+ val query2 = s"$baseQuery(id, name) ON ((1, 'x'), (2, 'y'), (3, 'z'))"
+ val query3 = s"$baseQuery(id, name) ON ((1, 'x'), (2, 'y'), (3, 'z')) STORED AS DIRECTORIES"
+ val e1 = intercept[ParseException] { parser.parsePlan(query1) }
+ val e2 = intercept[ParseException] { parser.parsePlan(query2) }
+ val e3 = intercept[ParseException] { parser.parsePlan(query3) }
+ assert(e1.getMessage.contains("Operation not allowed"))
+ assert(e2.getMessage.contains("Operation not allowed"))
+ assert(e3.getMessage.contains("Operation not allowed"))
+ }
+
+ test("create table - row format") {
+ val baseQuery = "CREATE TABLE my_table (id int, name string) ROW FORMAT"
+ val query1 = s"$baseQuery SERDE 'org.apache.poof.serde.Baff'"
+ val query2 = s"$baseQuery SERDE 'org.apache.poof.serde.Baff' WITH SERDEPROPERTIES ('k1'='v1')"
+ val query3 =
+ s"""
+ |$baseQuery DELIMITED FIELDS TERMINATED BY 'x' ESCAPED BY 'y'
+ |COLLECTION ITEMS TERMINATED BY 'a'
+ |MAP KEYS TERMINATED BY 'b'
+ |LINES TERMINATED BY '\n'
+ |NULL DEFINED AS 'c'
+ """.stripMargin
+ val (desc1, _) = extractTableDesc(query1)
+ val (desc2, _) = extractTableDesc(query2)
+ val (desc3, _) = extractTableDesc(query3)
+ assert(desc1.storage.serde == Some("org.apache.poof.serde.Baff"))
+ assert(desc1.storage.serdeProperties.isEmpty)
+ assert(desc2.storage.serde == Some("org.apache.poof.serde.Baff"))
+ assert(desc2.storage.serdeProperties == Map("k1" -> "v1"))
+ assert(desc3.storage.serdeProperties == Map(
+ "field.delim" -> "x",
+ "escape.delim" -> "y",
+ "serialization.format" -> "x",
+ "line.delim" -> "\n",
+ "colelction.delim" -> "a", // yes, it's a typo from Hive :)
+ "mapkey.delim" -> "b"))
+ }
+
+ test("create table - file format") {
+ val baseQuery = "CREATE TABLE my_table (id int, name string) STORED AS"
+ val query1 = s"$baseQuery INPUTFORMAT 'winput' OUTPUTFORMAT 'wowput'"
+ val query2 = s"$baseQuery ORC"
+ val (desc1, _) = extractTableDesc(query1)
+ val (desc2, _) = extractTableDesc(query2)
+ assert(desc1.storage.inputFormat == Some("winput"))
+ assert(desc1.storage.outputFormat == Some("wowput"))
+ assert(desc1.storage.serde.isEmpty)
+ assert(desc2.storage.inputFormat == Some("org.apache.hadoop.hive.ql.io.orc.OrcInputFormat"))
+ assert(desc2.storage.outputFormat == Some("org.apache.hadoop.hive.ql.io.orc.OrcOutputFormat"))
+ assert(desc2.storage.serde == Some("org.apache.hadoop.hive.ql.io.orc.OrcSerde"))
+ }
+
+ test("create table - storage handler") {
+ val baseQuery = "CREATE TABLE my_table (id int, name string) STORED BY"
+ val query1 = s"$baseQuery 'org.papachi.StorageHandler'"
+ val query2 = s"$baseQuery 'org.mamachi.StorageHandler' WITH SERDEPROPERTIES ('k1'='v1')"
+ val e1 = intercept[ParseException] { parser.parsePlan(query1) }
+ val e2 = intercept[ParseException] { parser.parsePlan(query2) }
+ assert(e1.getMessage.contains("Operation not allowed"))
+ assert(e2.getMessage.contains("Operation not allowed"))
+ }
+
+ test("create table - location") {
+ val query = "CREATE TABLE my_table (id int, name string) LOCATION '/path/to/mars'"
+ val (desc, _) = extractTableDesc(query)
+ assert(desc.storage.locationUri == Some("/path/to/mars"))
+ }
+
+ test("create table - properties") {
+ val query = "CREATE TABLE my_table (id int, name string) TBLPROPERTIES ('k1'='v1', 'k2'='v2')"
+ val (desc, _) = extractTableDesc(query)
+ assert(desc.properties == Map("k1" -> "v1", "k2" -> "v2"))
+ }
+
+ test("create table - everything!") {
+ val query =
+ """
+ |CREATE EXTERNAL TABLE IF NOT EXISTS dbx.my_table (id int, name string)
+ |COMMENT 'no comment'
+ |PARTITIONED BY (month int)
+ |ROW FORMAT SERDE 'org.apache.poof.serde.Baff' WITH SERDEPROPERTIES ('k1'='v1')
+ |STORED AS INPUTFORMAT 'winput' OUTPUTFORMAT 'wowput'
+ |LOCATION '/path/to/mercury'
+ |TBLPROPERTIES ('k1'='v1', 'k2'='v2')
+ """.stripMargin
+ val (desc, allowExisting) = extractTableDesc(query)
+ assert(allowExisting)
+ assert(desc.identifier.database == Some("dbx"))
+ assert(desc.identifier.table == "my_table")
+ assert(desc.tableType == CatalogTableType.EXTERNAL_TABLE)
+ assert(desc.schema == Seq(
+ CatalogColumn("id", "int"),
+ CatalogColumn("name", "string"),
+ CatalogColumn("month", "int")))
+ assert(desc.partitionColumnNames == Seq("month"))
+ assert(desc.sortColumnNames.isEmpty)
+ assert(desc.bucketColumnNames.isEmpty)
+ assert(desc.numBuckets == -1)
+ assert(desc.viewText.isEmpty)
+ assert(desc.viewOriginalText.isEmpty)
+ assert(desc.storage.locationUri == Some("/path/to/mercury"))
+ assert(desc.storage.inputFormat == Some("winput"))
+ assert(desc.storage.outputFormat == Some("wowput"))
+ assert(desc.storage.serde == Some("org.apache.poof.serde.Baff"))
+ assert(desc.storage.serdeProperties == Map("k1" -> "v1"))
+ assert(desc.properties == Map("k1" -> "v1", "k2" -> "v2"))
+ assert(desc.comment == Some("no comment"))
+ }
+
+ test("create view -- basic") {
+ val v1 = "CREATE VIEW view1 AS SELECT * FROM tab1"
+ val (desc, exists) = extractTableDesc(v1)
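+ // A plain CREATE VIEW yields a VIRTUAL_VIEW with no storage information and the SELECT text preserved.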
+ assert(!exists)
+ assert(desc.identifier.database.isEmpty)
+ assert(desc.identifier.table == "view1")
+ assert(desc.tableType == CatalogTableType.VIRTUAL_VIEW)
+ assert(desc.storage.locationUri.isEmpty)
+ assert(desc.schema == Seq.empty[CatalogColumn])
+ assert(desc.viewText == Option("SELECT * FROM tab1"))
+ assert(desc.viewOriginalText == Option("SELECT * FROM tab1"))
+ assert(desc.storage.serdeProperties == Map())
+ assert(desc.storage.inputFormat.isEmpty)
+ assert(desc.storage.outputFormat.isEmpty)
+ assert(desc.storage.serde.isEmpty)
+ assert(desc.properties == Map())
+ }
+
+ test("create view - full") {
+ val v1 =
+ """
+ |CREATE OR REPLACE VIEW IF NOT EXISTS view1
+ |(col1, col3)
+ |COMMENT 'BLABLA'
+ |TBLPROPERTIES('prop1Key'="prop1Val")
+ |AS SELECT * FROM tab1
+ """.stripMargin
+ val (desc, exists) = extractTableDesc(v1)
+ assert(exists)
+ assert(desc.identifier.database.isEmpty)
+ assert(desc.identifier.table == "view1")
+ assert(desc.tableType == CatalogTableType.VIRTUAL_VIEW)
+ assert(desc.storage.locationUri.isEmpty)
+ assert(desc.schema ==
+ CatalogColumn("col1", null, nullable = true, None) ::
+ CatalogColumn("col3", null, nullable = true, None) :: Nil)
+ assert(desc.viewText == Option("SELECT * FROM tab1"))
+ assert(desc.viewOriginalText == Option("SELECT * FROM tab1"))
+ assert(desc.storage.serdeProperties == Map())
+ assert(desc.storage.inputFormat.isEmpty)
+ assert(desc.storage.outputFormat.isEmpty)
+ assert(desc.storage.serde.isEmpty)
+ assert(desc.properties == Map("prop1Key" -> "prop1Val"))
+ assert(desc.comment == Option("BLABLA"))
+ }
+
+ test("create view -- partitioned view") {
+ val v1 = "CREATE VIEW view1 partitioned on (ds, hr) as select * from srcpart"
+ intercept[ParseException] {
+ parser.parsePlan(v1).isInstanceOf[HiveNativeCommand]
+ }
+ }
+
+ test("MSCK repair table (not supported)") {
+ assertUnsupported("MSCK REPAIR TABLE tab1")
+ }
+
+ test("create table like") {
+ val v1 = "CREATE TABLE table1 LIKE table2"
+ val (target, source, exists) = parser.parsePlan(v1).collect {
+ case CreateTableLike(t, s, allowExisting) => (t, s, allowExisting)
+ }.head
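+ // Without IF NOT EXISTS, the allowExisting flag should be false.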
+ assert(exists == false)
+ assert(target.database.isEmpty)
+ assert(target.table == "table1")
+ assert(source.database.isEmpty)
+ assert(source.table == "table2")
+
+ val v2 = "CREATE TABLE IF NOT EXISTS table1 LIKE table2"
+ val (target2, source2, exists2) = parser.parsePlan(v2).collect {
+ case CreateTableLike(t, s, allowExisting) => (t, s, allowExisting)
+ }.head
+ assert(exists2)
+ assert(target2.database.isEmpty)
+ assert(target2.table == "table1")
+ assert(source2.database.isEmpty)
+ assert(source2.table == "table2")
+ }
+
+}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveCatalogSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogSuite.scala
index 427f5747a0..3334c16f0b 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveCatalogSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogSuite.scala
@@ -26,9 +26,9 @@ import org.apache.spark.sql.hive.client.{HiveClient, IsolatedClientLoader}
import org.apache.spark.util.Utils
/**
- * Test suite for the [[HiveCatalog]].
+ * Test suite for the [[HiveExternalCatalog]].
*/
-class HiveCatalogSuite extends CatalogTestCases {
+class HiveExternalCatalogSuite extends CatalogTestCases {
private val client: HiveClient = {
IsolatedClientLoader.forVersion(
@@ -41,7 +41,7 @@ class HiveCatalogSuite extends CatalogTestCases {
protected override val utils: CatalogTestUtils = new CatalogTestUtils {
override val tableInputFormat: String = "org.apache.hadoop.mapred.SequenceFileInputFormat"
override val tableOutputFormat: String = "org.apache.hadoop.mapred.SequenceFileOutputFormat"
- override def newEmptyCatalog(): ExternalCatalog = new HiveCatalog(client)
+ override def newEmptyCatalog(): ExternalCatalog = new HiveExternalCatalog(client)
}
protected override def resetState(): Unit = client.reset()
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala
index 6967395613..8648834f0d 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala
@@ -83,12 +83,12 @@ class DataSourceWithHiveMetastoreCatalogSuite
.saveAsTable("t")
}
- val hiveTable = sessionState.catalog.getTable(TableIdentifier("t", Some("default")))
+ val hiveTable = sessionState.catalog.getTableMetadata(TableIdentifier("t", Some("default")))
assert(hiveTable.storage.inputFormat === Some(inputFormat))
assert(hiveTable.storage.outputFormat === Some(outputFormat))
assert(hiveTable.storage.serde === Some(serde))
- assert(hiveTable.partitionColumns.isEmpty)
+ assert(hiveTable.partitionColumnNames.isEmpty)
assert(hiveTable.tableType === CatalogTableType.MANAGED_TABLE)
val columns = hiveTable.schema
@@ -114,7 +114,8 @@ class DataSourceWithHiveMetastoreCatalogSuite
.saveAsTable("t")
}
- val hiveTable = sessionState.catalog.getTable(TableIdentifier("t", Some("default")))
+ val hiveTable =
+ sessionState.catalog.getTableMetadata(TableIdentifier("t", Some("default")))
assert(hiveTable.storage.inputFormat === Some(inputFormat))
assert(hiveTable.storage.outputFormat === Some(outputFormat))
assert(hiveTable.storage.serde === Some(serde))
@@ -144,12 +145,13 @@ class DataSourceWithHiveMetastoreCatalogSuite
|AS SELECT 1 AS d1, "val_1" AS d2
""".stripMargin)
- val hiveTable = sessionState.catalog.getTable(TableIdentifier("t", Some("default")))
+ val hiveTable =
+ sessionState.catalog.getTableMetadata(TableIdentifier("t", Some("default")))
assert(hiveTable.storage.inputFormat === Some(inputFormat))
assert(hiveTable.storage.outputFormat === Some(outputFormat))
assert(hiveTable.storage.serde === Some(serde))
- assert(hiveTable.partitionColumns.isEmpty)
+ assert(hiveTable.partitionColumnNames.isEmpty)
assert(hiveTable.tableType === CatalogTableType.EXTERNAL_TABLE)
val columns = hiveTable.schema
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveQlSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveQlSuite.scala
deleted file mode 100644
index 1c775db9b6..0000000000
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveQlSuite.scala
+++ /dev/null
@@ -1,231 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.hive
-
-import org.apache.hadoop.hive.serde.serdeConstants
-import org.scalatest.BeforeAndAfterAll
-
-import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql.AnalysisException
-import org.apache.spark.sql.catalyst.catalog.{CatalogColumn, CatalogTable, CatalogTableType}
-import org.apache.spark.sql.catalyst.expressions.JsonTuple
-import org.apache.spark.sql.catalyst.parser.SimpleParserConf
-import org.apache.spark.sql.catalyst.plans.logical.Generate
-
-class HiveQlSuite extends SparkFunSuite with BeforeAndAfterAll {
- val parser = new HiveQl(SimpleParserConf())
-
- private def extractTableDesc(sql: String): (CatalogTable, Boolean) = {
- parser.parsePlan(sql).collect {
- case CreateTableAsSelect(desc, child, allowExisting) => (desc, allowExisting)
- }.head
- }
-
- test("Test CTAS #1") {
- val s1 =
- """CREATE EXTERNAL TABLE IF NOT EXISTS mydb.page_view
- |(viewTime INT,
- |userid BIGINT,
- |page_url STRING,
- |referrer_url STRING,
- |ip STRING COMMENT 'IP Address of the User',
- |country STRING COMMENT 'country of origination')
- |COMMENT 'This is the staging page view table'
- |PARTITIONED BY (dt STRING COMMENT 'date type', hour STRING COMMENT 'hour of the day')
- |ROW FORMAT DELIMITED FIELDS TERMINATED BY '\054' STORED AS RCFILE
- |LOCATION '/user/external/page_view'
- |TBLPROPERTIES ('p1'='v1', 'p2'='v2')
- |AS SELECT * FROM src""".stripMargin
-
- val (desc, exists) = extractTableDesc(s1)
- assert(exists)
- assert(desc.name.database == Some("mydb"))
- assert(desc.name.table == "page_view")
- assert(desc.tableType == CatalogTableType.EXTERNAL_TABLE)
- assert(desc.storage.locationUri == Some("/user/external/page_view"))
- assert(desc.schema ==
- CatalogColumn("viewtime", "int") ::
- CatalogColumn("userid", "bigint") ::
- CatalogColumn("page_url", "string") ::
- CatalogColumn("referrer_url", "string") ::
- CatalogColumn("ip", "string", comment = Some("IP Address of the User")) ::
- CatalogColumn("country", "string", comment = Some("country of origination")) :: Nil)
- // TODO will be SQLText
- assert(desc.viewText == Option("This is the staging page view table"))
- assert(desc.partitionColumns ==
- CatalogColumn("dt", "string", comment = Some("date type")) ::
- CatalogColumn("hour", "string", comment = Some("hour of the day")) :: Nil)
- assert(desc.storage.serdeProperties ==
- Map((serdeConstants.SERIALIZATION_FORMAT, "\u002C"), (serdeConstants.FIELD_DELIM, "\u002C")))
- assert(desc.storage.inputFormat == Some("org.apache.hadoop.hive.ql.io.RCFileInputFormat"))
- assert(desc.storage.outputFormat == Some("org.apache.hadoop.hive.ql.io.RCFileOutputFormat"))
- assert(desc.storage.serde ==
- Some("org.apache.hadoop.hive.serde2.columnar.LazyBinaryColumnarSerDe"))
- assert(desc.properties == Map(("p1", "v1"), ("p2", "v2")))
- }
-
- test("Test CTAS #2") {
- val s2 =
- """CREATE EXTERNAL TABLE IF NOT EXISTS mydb.page_view
- |(viewTime INT,
- |userid BIGINT,
- |page_url STRING,
- |referrer_url STRING,
- |ip STRING COMMENT 'IP Address of the User',
- |country STRING COMMENT 'country of origination')
- |COMMENT 'This is the staging page view table'
- |PARTITIONED BY (dt STRING COMMENT 'date type', hour STRING COMMENT 'hour of the day')
- |ROW FORMAT SERDE 'parquet.hive.serde.ParquetHiveSerDe'
- | STORED AS
- | INPUTFORMAT 'parquet.hive.DeprecatedParquetInputFormat'
- | OUTPUTFORMAT 'parquet.hive.DeprecatedParquetOutputFormat'
- |LOCATION '/user/external/page_view'
- |TBLPROPERTIES ('p1'='v1', 'p2'='v2')
- |AS SELECT * FROM src""".stripMargin
-
- val (desc, exists) = extractTableDesc(s2)
- assert(exists)
- assert(desc.name.database == Some("mydb"))
- assert(desc.name.table == "page_view")
- assert(desc.tableType == CatalogTableType.EXTERNAL_TABLE)
- assert(desc.storage.locationUri == Some("/user/external/page_view"))
- assert(desc.schema ==
- CatalogColumn("viewtime", "int") ::
- CatalogColumn("userid", "bigint") ::
- CatalogColumn("page_url", "string") ::
- CatalogColumn("referrer_url", "string") ::
- CatalogColumn("ip", "string", comment = Some("IP Address of the User")) ::
- CatalogColumn("country", "string", comment = Some("country of origination")) :: Nil)
- // TODO will be SQLText
- assert(desc.viewText == Option("This is the staging page view table"))
- assert(desc.partitionColumns ==
- CatalogColumn("dt", "string", comment = Some("date type")) ::
- CatalogColumn("hour", "string", comment = Some("hour of the day")) :: Nil)
- assert(desc.storage.serdeProperties == Map())
- assert(desc.storage.inputFormat == Some("parquet.hive.DeprecatedParquetInputFormat"))
- assert(desc.storage.outputFormat == Some("parquet.hive.DeprecatedParquetOutputFormat"))
- assert(desc.storage.serde == Some("parquet.hive.serde.ParquetHiveSerDe"))
- assert(desc.properties == Map(("p1", "v1"), ("p2", "v2")))
- }
-
- test("Test CTAS #3") {
- val s3 = """CREATE TABLE page_view AS SELECT * FROM src"""
- val (desc, exists) = extractTableDesc(s3)
- assert(exists == false)
- assert(desc.name.database == None)
- assert(desc.name.table == "page_view")
- assert(desc.tableType == CatalogTableType.MANAGED_TABLE)
- assert(desc.storage.locationUri == None)
- assert(desc.schema == Seq.empty[CatalogColumn])
- assert(desc.viewText == None) // TODO will be SQLText
- assert(desc.storage.serdeProperties == Map())
- assert(desc.storage.inputFormat == Some("org.apache.hadoop.mapred.TextInputFormat"))
- assert(desc.storage.outputFormat ==
- Some("org.apache.hadoop.hive.ql.io.IgnoreKeyTextOutputFormat"))
- assert(desc.storage.serde.isEmpty)
- assert(desc.properties == Map())
- }
-
- test("Test CTAS #4") {
- val s4 =
- """CREATE TABLE page_view
- |STORED BY 'storage.handler.class.name' AS SELECT * FROM src""".stripMargin
- intercept[AnalysisException] {
- extractTableDesc(s4)
- }
- }
-
- test("Test CTAS #5") {
- val s5 = """CREATE TABLE ctas2
- | ROW FORMAT SERDE "org.apache.hadoop.hive.serde2.columnar.ColumnarSerDe"
- | WITH SERDEPROPERTIES("serde_p1"="p1","serde_p2"="p2")
- | STORED AS RCFile
- | TBLPROPERTIES("tbl_p1"="p11", "tbl_p2"="p22")
- | AS
- | SELECT key, value
- | FROM src
- | ORDER BY key, value""".stripMargin
- val (desc, exists) = extractTableDesc(s5)
- assert(exists == false)
- assert(desc.name.database == None)
- assert(desc.name.table == "ctas2")
- assert(desc.tableType == CatalogTableType.MANAGED_TABLE)
- assert(desc.storage.locationUri == None)
- assert(desc.schema == Seq.empty[CatalogColumn])
- assert(desc.viewText == None) // TODO will be SQLText
- assert(desc.storage.serdeProperties == Map(("serde_p1" -> "p1"), ("serde_p2" -> "p2")))
- assert(desc.storage.inputFormat == Some("org.apache.hadoop.hive.ql.io.RCFileInputFormat"))
- assert(desc.storage.outputFormat == Some("org.apache.hadoop.hive.ql.io.RCFileOutputFormat"))
- assert(desc.storage.serde == Some("org.apache.hadoop.hive.serde2.columnar.ColumnarSerDe"))
- assert(desc.properties == Map(("tbl_p1" -> "p11"), ("tbl_p2" -> "p22")))
- }
-
- test("Invalid interval term should throw AnalysisException") {
- def assertError(sql: String, errorMessage: String): Unit = {
- val e = intercept[AnalysisException] {
- parser.parsePlan(sql)
- }
- assert(e.getMessage.contains(errorMessage))
- }
- assertError("select interval '42-32' year to month",
- "month 32 outside range [0, 11]")
- assertError("select interval '5 49:12:15' day to second",
- "hour 49 outside range [0, 23]")
- assertError("select interval '.1111111111' second",
- "nanosecond 1111111111 outside range")
- }
-
- test("use native json_tuple instead of hive's UDTF in LATERAL VIEW") {
- val plan = parser.parsePlan(
- """
- |SELECT *
- |FROM (SELECT '{"f1": "value1", "f2": 12}' json) test
- |LATERAL VIEW json_tuple(json, 'f1', 'f2') jt AS a, b
- """.stripMargin)
-
- assert(plan.children.head.asInstanceOf[Generate].generator.isInstanceOf[JsonTuple])
- }
-
- test("use backticks in output of Script Transform") {
- val plan = parser.parsePlan(
- """SELECT `t`.`thing1`
- |FROM (SELECT TRANSFORM (`parquet_t1`.`key`, `parquet_t1`.`value`)
- |USING 'cat' AS (`thing1` int, `thing2` string) FROM `default`.`parquet_t1`) AS t
- """.stripMargin)
- }
-
- test("use backticks in output of Generator") {
- val plan = parser.parsePlan(
- """
- |SELECT `gentab2`.`gencol2`
- |FROM `default`.`src`
- |LATERAL VIEW explode(array(array(1, 2, 3))) `gentab1` AS `gencol1`
- |LATERAL VIEW explode(`gentab1`.`gencol1`) `gentab2` AS `gencol2`
- """.stripMargin)
- }
-
- test("use escaped backticks in output of Generator") {
- val plan = parser.parsePlan(
- """
- |SELECT `gen``tab2`.`gen``col2`
- |FROM `default`.`src`
- |LATERAL VIEW explode(array(array(1, 2, 3))) `gen``tab1` AS `gen``col1`
- |LATERAL VIEW explode(`gen``tab1`.`gen``col1`) `gen``tab2` AS `gen``col2`
- """.stripMargin)
- }
-}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala
index 16747cab37..c5417b06a4 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala
@@ -31,7 +31,9 @@ import org.scalatest.time.SpanSugar._
import org.apache.spark._
import org.apache.spark.internal.Logging
-import org.apache.spark.sql.{QueryTest, SQLContext}
+import org.apache.spark.sql.{QueryTest, Row, SQLContext}
+import org.apache.spark.sql.catalyst.catalog.CatalogFunction
+import org.apache.spark.sql.catalyst.FunctionIdentifier
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.hive.test.{TestHive, TestHiveContext}
import org.apache.spark.sql.test.ProcessTestUtils.ProcessOutputCapturer
@@ -55,6 +57,57 @@ class HiveSparkSubmitSuite
System.setProperty("spark.testing", "true")
}
+ test("temporary Hive UDF: define a UDF and use it") {
+ val unusedJar = TestUtils.createJarWithClasses(Seq.empty)
+ val jar1 = TestUtils.createJarWithClasses(Seq("SparkSubmitClassA"))
+ val jar2 = TestUtils.createJarWithClasses(Seq("SparkSubmitClassB"))
+ val jarsString = Seq(jar1, jar2).map(j => j.toString).mkString(",")
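+ // Launch the test application through spark-submit so the UDF definition and usage run in a separate JVM.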
+ val args = Seq(
+ "--class", TemporaryHiveUDFTest.getClass.getName.stripSuffix("$"),
+ "--name", "TemporaryHiveUDFTest",
+ "--master", "local-cluster[2,1,1024]",
+ "--conf", "spark.ui.enabled=false",
+ "--conf", "spark.master.rest.enabled=false",
+ "--driver-java-options", "-Dderby.system.durability=test",
+ "--jars", jarsString,
+ unusedJar.toString, "SparkSubmitClassA", "SparkSubmitClassB")
+ runSparkSubmit(args)
+ }
+
+ test("permanent Hive UDF: define a UDF and use it") {
+ val unusedJar = TestUtils.createJarWithClasses(Seq.empty)
+ val jar1 = TestUtils.createJarWithClasses(Seq("SparkSubmitClassA"))
+ val jar2 = TestUtils.createJarWithClasses(Seq("SparkSubmitClassB"))
+ val jarsString = Seq(jar1, jar2).map(j => j.toString).mkString(",")
+ val args = Seq(
+ "--class", PermanentHiveUDFTest1.getClass.getName.stripSuffix("$"),
+ "--name", "PermanentHiveUDFTest1",
+ "--master", "local-cluster[2,1,1024]",
+ "--conf", "spark.ui.enabled=false",
+ "--conf", "spark.master.rest.enabled=false",
+ "--driver-java-options", "-Dderby.system.durability=test",
+ "--jars", jarsString,
+ unusedJar.toString, "SparkSubmitClassA", "SparkSubmitClassB")
+ runSparkSubmit(args)
+ }
+
+ test("permanent Hive UDF: use a already defined permanent function") {
+ val unusedJar = TestUtils.createJarWithClasses(Seq.empty)
+ val jar1 = TestUtils.createJarWithClasses(Seq("SparkSubmitClassA"))
+ val jar2 = TestUtils.createJarWithClasses(Seq("SparkSubmitClassB"))
+ val jarsString = Seq(jar1, jar2).map(j => j.toString).mkString(",")
+ val args = Seq(
+ "--class", PermanentHiveUDFTest2.getClass.getName.stripSuffix("$"),
+ "--name", "PermanentHiveUDFTest2",
+ "--master", "local-cluster[2,1,1024]",
+ "--conf", "spark.ui.enabled=false",
+ "--conf", "spark.master.rest.enabled=false",
+ "--driver-java-options", "-Dderby.system.durability=test",
+ "--jars", jarsString,
+ unusedJar.toString, "SparkSubmitClassA", "SparkSubmitClassB")
+ runSparkSubmit(args)
+ }
+
test("SPARK-8368: includes jars passed in through --jars") {
val unusedJar = TestUtils.createJarWithClasses(Seq.empty)
val jar1 = TestUtils.createJarWithClasses(Seq("SparkSubmitClassA"))
@@ -135,6 +188,19 @@ class HiveSparkSubmitSuite
runSparkSubmit(args)
}
+ test("SPARK-14244 fix window partition size attribute binding failure") {
+ val unusedJar = TestUtils.createJarWithClasses(Seq.empty)
+ val args = Seq(
+ "--class", SPARK_14244.getClass.getName.stripSuffix("$"),
+ "--name", "SparkSQLConfTest",
+ "--master", "local-cluster[2,1,1024]",
+ "--conf", "spark.ui.enabled=false",
+ "--conf", "spark.master.rest.enabled=false",
+ "--driver-java-options", "-Dderby.system.durability=test",
+ unusedJar.toString)
+ runSparkSubmit(args)
+ }
+
// NOTE: This is an expensive operation in terms of time (10 seconds+). Use sparingly.
// This is copied from org.apache.spark.deploy.SparkSubmitSuite
private def runSparkSubmit(args: Seq[String]): Unit = {
@@ -195,6 +261,118 @@ class HiveSparkSubmitSuite
}
}
+// This application is used to test defining a new Hive UDF (with an associated jar)
+// and using that UDF. We need to run this test in a separate JVM to make sure we
+// can load the jar associated with the function.
+object TemporaryHiveUDFTest extends Logging {
+ def main(args: Array[String]) {
+ Utils.configTestLog4j("INFO")
+ val conf = new SparkConf()
+ conf.set("spark.ui.enabled", "false")
+ val sc = new SparkContext(conf)
+ val hiveContext = new TestHiveContext(sc)
+
+ // Load a Hive UDF from the jar.
+ logInfo("Registering a temporary Hive UDF provided in a jar.")
+ val jar = hiveContext.getHiveFile("hive-contrib-0.13.1.jar").getCanonicalPath
+ hiveContext.sql(
+ s"""
+ |CREATE TEMPORARY FUNCTION example_max
+ |AS 'org.apache.hadoop.hive.contrib.udaf.example.UDAFExampleMax'
+ |USING JAR '$jar'
+ """.stripMargin)
+ val source =
+ hiveContext.createDataFrame((1 to 10).map(i => (i, s"str$i"))).toDF("key", "val")
+ source.registerTempTable("sourceTable")
+ // Actually use the loaded UDF.
+ logInfo("Using the UDF.")
+ val result = hiveContext.sql(
+ "SELECT example_max(key) as key, val FROM sourceTable GROUP BY val")
+ logInfo("Running a simple query on the table.")
+ val count = result.orderBy("key", "val").count()
+ if (count != 10) {
+ throw new Exception(s"Result table should have 10 rows instead of $count rows")
+ }
+ hiveContext.sql("DROP temporary FUNCTION example_max")
+ logInfo("Test finishes.")
+ sc.stop()
+ }
+}
+
+// This application is used to test defining a new Hive UDF (with an associated jar)
+// and using that UDF. We need to run this test in a separate JVM to make sure we
+// can load the jar associated with the function.
+object PermanentHiveUDFTest1 extends Logging {
+ def main(args: Array[String]) {
+ Utils.configTestLog4j("INFO")
+ val conf = new SparkConf()
+ conf.set("spark.ui.enabled", "false")
+ val sc = new SparkContext(conf)
+ val hiveContext = new TestHiveContext(sc)
+
+ // Load a Hive UDF from the jar.
+ logInfo("Registering a permanent Hive UDF provided in a jar.")
+ val jar = hiveContext.getHiveFile("hive-contrib-0.13.1.jar").getCanonicalPath
+ hiveContext.sql(
+ s"""
+ |CREATE FUNCTION example_max
+ |AS 'org.apache.hadoop.hive.contrib.udaf.example.UDAFExampleMax'
+ |USING JAR '$jar'
+ """.stripMargin)
+ val source =
+ hiveContext.createDataFrame((1 to 10).map(i => (i, s"str$i"))).toDF("key", "val")
+ source.registerTempTable("sourceTable")
+ // Actually use the loaded UDF.
+ logInfo("Using the UDF.")
+ val result = hiveContext.sql(
+ "SELECT example_max(key) as key, val FROM sourceTable GROUP BY val")
+ logInfo("Running a simple query on the table.")
+ val count = result.orderBy("key", "val").count()
+ if (count != 10) {
+ throw new Exception(s"Result table should have 10 rows instead of $count rows")
+ }
+ hiveContext.sql("DROP FUNCTION example_max")
+ logInfo("Test finishes.")
+ sc.stop()
+ }
+}
+
+// This application is used to test that a pre-defined permanent function with jar
+// resources can be used. We need to run this test in a separate JVM to make sure we
+// can load the jar associated with the function.
+object PermanentHiveUDFTest2 extends Logging {
+ def main(args: Array[String]) {
+ Utils.configTestLog4j("INFO")
+ val conf = new SparkConf()
+ conf.set("spark.ui.enabled", "false")
+ val sc = new SparkContext(conf)
+ val hiveContext = new TestHiveContext(sc)
+ // Load a Hive UDF from the jar.
+ logInfo("Write the metadata of a permanent Hive UDF into metastore.")
+ val jar = hiveContext.getHiveFile("hive-contrib-0.13.1.jar").getCanonicalPath
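+ // Register the function metadata, including its jar resource, through the session catalog.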
+ val function = CatalogFunction(
+ FunctionIdentifier("example_max"),
+ "org.apache.hadoop.hive.contrib.udaf.example.UDAFExampleMax",
+ ("JAR" -> jar) :: Nil)
+ hiveContext.sessionState.catalog.createFunction(function, ignoreIfExists = false)
+ val source =
+ hiveContext.createDataFrame((1 to 10).map(i => (i, s"str$i"))).toDF("key", "val")
+ source.registerTempTable("sourceTable")
+ // Actually use the loaded UDF.
+ logInfo("Using the UDF.")
+ val result = hiveContext.sql(
+ "SELECT example_max(key) as key, val FROM sourceTable GROUP BY val")
+ logInfo("Running a simple query on the table.")
+ val count = result.orderBy("key", "val").count()
+ if (count != 10) {
+ throw new Exception(s"Result table should have 10 rows instead of $count rows")
+ }
+ hiveContext.sql("DROP FUNCTION example_max")
+ logInfo("Test finishes.")
+ sc.stop()
+ }
+}
+
// This object is used for testing SPARK-8368: https://issues.apache.org/jira/browse/SPARK-8368.
// We test if we can load user jars in both driver and executors when HiveContext is used.
object SparkSubmitClassLoaderTest extends Logging {
@@ -378,3 +556,32 @@ object SPARK_11009 extends QueryTest {
}
}
}
+
+object SPARK_14244 extends QueryTest {
+ import org.apache.spark.sql.expressions.Window
+ import org.apache.spark.sql.functions._
+
+ protected var sqlContext: SQLContext = _
+
+ def main(args: Array[String]): Unit = {
+ Utils.configTestLog4j("INFO")
+
+ val sparkContext = new SparkContext(
+ new SparkConf()
+ .set("spark.ui.enabled", "false")
+ .set("spark.sql.shuffle.partitions", "100"))
+
+ val hiveContext = new TestHiveContext(sparkContext)
+ sqlContext = hiveContext
+
+ import hiveContext.implicits._
+
+ try {
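+ // cume_dist over an ordered window exercises the partition-size attribute binding fixed by SPARK-14244.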
+ val window = Window.orderBy('id)
+ val df = sqlContext.range(2).select(cume_dist().over(window).as('cdist)).orderBy('cdist)
+ checkAnswer(df, Seq(Row(0.5D), Row(1.0D)))
+ } finally {
+ sparkContext.stop()
+ }
+ }
+}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala
index 40e9c9362c..4db95636e7 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala
@@ -81,7 +81,7 @@ class InsertIntoHiveTableSuite extends QueryTest with TestHiveSingleton with Bef
test("Double create fails when allowExisting = false") {
sql("CREATE TABLE doubleCreateAndInsertTest (key int, value string)")
- intercept[QueryExecutionException] {
+ intercept[AnalysisException] {
sql("CREATE TABLE doubleCreateAndInsertTest (key int, value string)")
}
}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ListTablesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ListTablesSuite.scala
index 5272f4192e..e8188e5f02 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ListTablesSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ListTablesSuite.scala
@@ -34,7 +34,7 @@ class ListTablesSuite extends QueryTest with TestHiveSingleton with BeforeAndAft
super.beforeAll()
// The catalog in HiveContext is a case insensitive one.
sessionState.catalog.createTempTable(
- "ListTablesSuiteTable", df.logicalPlan, ignoreIfExists = true)
+ "ListTablesSuiteTable", df.logicalPlan, overrideIfExists = true)
sql("CREATE TABLE HiveListTablesSuiteTable (key int, value string)")
sql("CREATE DATABASE IF NOT EXISTS ListTablesSuiteDB")
sql("CREATE TABLE ListTablesSuiteDB.HiveInDBListTablesSuiteTable (key int, value string)")
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala
index 71652897e6..3c299daa77 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala
@@ -722,7 +722,7 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv
withTable(tableName) {
val schema = StructType(StructField("int", IntegerType, true) :: Nil)
val hiveTable = CatalogTable(
- name = TableIdentifier(tableName, Some("default")),
+ identifier = TableIdentifier(tableName, Some("default")),
tableType = CatalogTableType.MANAGED_TABLE,
schema = Seq.empty,
storage = CatalogStorageFormat(
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala
index ae026ed496..05318f51af 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala
@@ -21,7 +21,6 @@ import scala.reflect.ClassTag
import org.apache.spark.sql.{QueryTest, Row}
import org.apache.spark.sql.catalyst.TableIdentifier
-import org.apache.spark.sql.catalyst.parser.SimpleParserConf
import org.apache.spark.sql.execution.joins._
import org.apache.spark.sql.hive.execution._
import org.apache.spark.sql.hive.test.TestHiveSingleton
@@ -30,11 +29,9 @@ import org.apache.spark.sql.internal.SQLConf
class StatisticsSuite extends QueryTest with TestHiveSingleton {
import hiveContext.sql
- val parser = new HiveQl(SimpleParserConf())
-
test("parse analyze commands") {
def assertAnalyzeCommand(analyzeCommand: String, c: Class[_]) {
- val parsed = parser.parsePlan(analyzeCommand)
+ val parsed = HiveSqlParser.parsePlan(analyzeCommand)
val operators = parsed.collect {
case a: AnalyzeTable => a
case o => o
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala
index 3ab4576811..d1aa5aa931 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala
@@ -17,12 +17,51 @@
package org.apache.spark.sql.hive
-import org.apache.spark.sql.QueryTest
+import org.scalatest.BeforeAndAfterEach
+
+import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row}
import org.apache.spark.sql.hive.test.TestHiveSingleton
+import org.apache.spark.sql.test.SQLTestUtils
case class FunctionResult(f1: String, f2: String)
-class UDFSuite extends QueryTest with TestHiveSingleton {
+/**
+ * A test suite for UDF related functionality. Because the Hive metastore is
+ * case insensitive, the database and function names used here mix upper-case
+ * and lower-case letters.
+ */
+class UDFSuite
+ extends QueryTest
+ with SQLTestUtils
+ with TestHiveSingleton
+ with BeforeAndAfterEach {
+
+ import hiveContext.implicits._
+
+ private[this] val functionName = "myUPper"
+ private[this] val functionNameUpper = "MYUPPER"
+ private[this] val functionNameLower = "myupper"
+
+ private[this] val functionClass =
+ classOf[org.apache.hadoop.hive.ql.udf.generic.GenericUDFUpper].getCanonicalName
+
+ private var testDF: DataFrame = null
+ private[this] val testTableName = "testDF_UDFSuite"
+ private var expectedDF: DataFrame = null
+
+ override def beforeAll(): Unit = {
+ sql("USE default")
+
+ testDF = (1 to 10).map(i => s"sTr$i").toDF("value")
+ testDF.registerTempTable(testTableName)
+ expectedDF = (1 to 10).map(i => s"STR$i").toDF("value")
+ super.beforeAll()
+ }
+
+ override def afterEach(): Unit = {
+ sql("USE default")
+ super.afterEach()
+ }
test("UDF case insensitive") {
hiveContext.udf.register("random0", () => { Math.random() })
@@ -32,4 +71,128 @@ class UDFSuite extends QueryTest with TestHiveSingleton {
assert(hiveContext.sql("SELECT RANDOm1() FROM src LIMIT 1").head().getDouble(0) >= 0.0)
assert(hiveContext.sql("SELECT strlenscala('test', 1) FROM src LIMIT 1").head().getInt(0) === 5)
}
+
+ test("temporary function: create and drop") {
+ withUserDefinedFunction(functionName -> true) {
+ intercept[AnalysisException] {
+ sql(s"CREATE TEMPORARY FUNCTION default.$functionName AS '$functionClass'")
+ }
+ sql(s"CREATE TEMPORARY FUNCTION $functionName AS '$functionClass'")
+ checkAnswer(
+ sql(s"SELECT $functionNameLower(value) from $testTableName"),
+ expectedDF
+ )
+ intercept[AnalysisException] {
+ sql(s"DROP TEMPORARY FUNCTION default.$functionName")
+ }
+ }
+ }
+
+ test("permanent function: create and drop without specifying db name") {
+ withUserDefinedFunction(functionName -> false) {
+ sql(s"CREATE FUNCTION $functionName AS '$functionClass'")
+ checkAnswer(
+ sql("SHOW functions like '.*upper'"),
+ Row(s"default.$functionNameLower")
+ )
+ checkAnswer(
+ sql(s"SELECT $functionName(value) from $testTableName"),
+ expectedDF
+ )
+ assert(
+ sql("SHOW functions").collect()
+ .map(_.getString(0))
+ .contains(s"default.$functionNameLower"))
+ }
+ }
+
+ test("permanent function: create and drop with a db name") {
+ // For this block, the drop function command uses functionName as the function name.
+ withUserDefinedFunction(functionNameUpper -> false) {
+ sql(s"CREATE FUNCTION default.$functionName AS '$functionClass'")
+ // TODO: Re-enable this once qualified and unqualified function names can be
+ // distinguished in SessionCatalog.lookupFunction.
+ // checkAnswer(
+ // sql(s"SELECT default.myuPPer(value) from $testTableName"),
+ // expectedDF
+ // )
+ checkAnswer(
+ sql(s"SELECT $functionName(value) from $testTableName"),
+ expectedDF
+ )
+ checkAnswer(
+ sql(s"SELECT default.$functionName(value) from $testTableName"),
+ expectedDF
+ )
+ }
+
+ // For this block, the drop function command uses default.functionName as the function name.
+ withUserDefinedFunction(s"DEfault.$functionNameLower" -> false) {
+ sql(s"CREATE FUNCTION dEFault.$functionName AS '$functionClass'")
+ checkAnswer(
+ sql(s"SELECT $functionNameUpper(value) from $testTableName"),
+ expectedDF
+ )
+ }
+ }
+
+ test("permanent function: create and drop a function in another db") {
+ // For this block, the drop function command uses functionName as the function name.
+ withTempDatabase { dbName =>
+ withUserDefinedFunction(functionName -> false) {
+ sql(s"CREATE FUNCTION $dbName.$functionName AS '$functionClass'")
+ // TODO: Re-enable this once qualified and unqualified function names can be distinguished.
+ // checkAnswer(
+ // sql(s"SELECT $dbName.myuPPer(value) from $testTableName"),
+ // expectedDF
+ // )
+
+ checkAnswer(
+ sql(s"SHOW FUNCTIONS like $dbName.$functionNameUpper"),
+ Row(s"$dbName.$functionNameLower")
+ )
+
+ sql(s"USE $dbName")
+
+ checkAnswer(
+ sql(s"SELECT $functionName(value) from $testTableName"),
+ expectedDF
+ )
+
+ sql(s"USE default")
+
+ checkAnswer(
+ sql(s"SELECT $dbName.$functionName(value) from $testTableName"),
+ expectedDF
+ )
+
+ sql(s"USE $dbName")
+ }
+
+ sql(s"USE default")
+
+ // For this block, the drop function command uses $dbName.functionName as the function name.
+ withUserDefinedFunction(s"$dbName.$functionNameUpper" -> false) {
+ sql(s"CREATE FUNCTION $dbName.$functionName AS '$functionClass'")
+ // TODO: Re-enable this once qualified and unqualified function names can be distinguished.
+ // checkAnswer(
+ // sql(s"SELECT $dbName.myupper(value) from $testTableName"),
+ // expectedDF
+ // )
+
+ sql(s"USE $dbName")
+
+ assert(
+ sql("SHOW functions").collect()
+ .map(_.getString(0))
+ .contains(s"$dbName.$functionNameLower"))
+ checkAnswer(
+ sql(s"SELECT $functionNameLower(value) from $testTableName"),
+ expectedDF
+ )
+
+ sql(s"USE default")
+ }
+ }
+ }
}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala
index d59bca4c7e..8b0719209d 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala
@@ -148,7 +148,7 @@ class VersionsSuite extends SparkFunSuite with Logging {
test(s"$version: createTable") {
val table =
CatalogTable(
- name = TableIdentifier("src", Some("default")),
+ identifier = TableIdentifier("src", Some("default")),
tableType = CatalogTableType.MANAGED_TABLE,
schema = Seq(CatalogColumn("key", "int")),
storage = CatalogStorageFormat(
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveCommandSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveCommandSuite.scala
new file mode 100644
index 0000000000..061d1512a5
--- /dev/null
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveCommandSuite.scala
@@ -0,0 +1,125 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive.execution
+
+import org.apache.spark.sql.{AnalysisException, QueryTest, Row}
+import org.apache.spark.sql.hive.test.TestHiveSingleton
+import org.apache.spark.sql.test.SQLTestUtils
+
+class HiveCommandSuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
+ protected override def beforeAll(): Unit = {
+ super.beforeAll()
+ sql(
+ """
+ |CREATE TABLE parquet_tab1 (c1 INT, c2 STRING)
+ |USING org.apache.spark.sql.parquet.DefaultSource
+ """.stripMargin)
+
+ sql(
+ """
+ |CREATE EXTERNAL TABLE parquet_tab2 (c1 INT, c2 STRING)
+ |STORED AS PARQUET
+ |TBLPROPERTIES('prop1Key'="prop1Val", '`prop2Key`'="prop2Val")
+ """.stripMargin)
+ }
+
+ override protected def afterAll(): Unit = {
+ try {
+ sql("DROP TABLE IF EXISTS parquet_tab1")
+ sql("DROP TABLE IF EXISTS parquet_tab2")
+ } finally {
+ super.afterAll()
+ }
+ }
+
+ test("show tables") {
+ withTable("show1a", "show2b") {
+ sql("CREATE TABLE show1a(c1 int)")
+ sql("CREATE TABLE show2b(c2 int)")
+ checkAnswer(
+ sql("SHOW TABLES IN default 'show1*'"),
+ Row("show1a", false) :: Nil)
+ checkAnswer(
+ sql("SHOW TABLES IN default 'show1*|show2*'"),
+ Row("show1a", false) ::
+ Row("show2b", false) :: Nil)
+ checkAnswer(
+ sql("SHOW TABLES 'show1*|show2*'"),
+ Row("show1a", false) ::
+ Row("show2b", false) :: Nil)
+ assert(
+ sql("SHOW TABLES").count() >= 2)
+ assert(
+ sql("SHOW TABLES IN default").count() >= 2)
+ }
+ }
+
+ test("show tblproperties of data source tables - basic") {
+ checkAnswer(
+ sql("SHOW TBLPROPERTIES parquet_tab1")
+ .filter(s"key = 'spark.sql.sources.provider'"),
+ Row("spark.sql.sources.provider", "org.apache.spark.sql.parquet.DefaultSource") :: Nil
+ )
+
+ checkAnswer(
+ sql("SHOW TBLPROPERTIES parquet_tab1(spark.sql.sources.provider)"),
+ Row("org.apache.spark.sql.parquet.DefaultSource") :: Nil
+ )
+
+ checkAnswer(
+ sql("SHOW TBLPROPERTIES parquet_tab1")
+ .filter(s"key = 'spark.sql.sources.schema.numParts'"),
+ Row("spark.sql.sources.schema.numParts", "1") :: Nil
+ )
+
+ checkAnswer(
+ sql("SHOW TBLPROPERTIES parquet_tab1('spark.sql.sources.schema.numParts')"),
+ Row("1"))
+ }
+
+ test("show tblproperties for datasource table - errors") {
+ val message1 = intercept[AnalysisException] {
+ sql("SHOW TBLPROPERTIES badtable")
+ }.getMessage
+ assert(message1.contains("Table or View badtable not found in database default"))
+
+ // When the key is not found, a row containing the error message is returned.
+ checkAnswer(
+ sql("SHOW TBLPROPERTIES parquet_tab1('invalid.prop.key')"),
+ Row("Table default.parquet_tab1 does not have property: invalid.prop.key") :: Nil
+ )
+ }
+
+ test("show tblproperties for hive table") {
+ checkAnswer(sql("SHOW TBLPROPERTIES parquet_tab2('prop1Key')"), Row("prop1Val"))
+ checkAnswer(sql("SHOW TBLPROPERTIES parquet_tab2('`prop2Key`')"), Row("prop2Val"))
+ }
+
+ test("show tblproperties for spark temporary table - empty row") {
+ withTempTable("parquet_temp") {
+ sql(
+ """
+ |CREATE TEMPORARY TABLE parquet_temp (c1 INT, c2 STRING)
+ |USING org.apache.spark.sql.parquet.DefaultSource
+ """.stripMargin)
+
+ // An empty sequence of rows is returned for a session temporary table.
+ checkAnswer(sql("SHOW TBLPROPERTIES parquet_temp"), Nil)
+ }
+ }
+}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala
index cfca93bbf0..e67fcbedc3 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala
@@ -480,7 +480,11 @@ abstract class HiveComparisonTest
val executions = queryList.map(new TestHive.QueryExecution(_))
executions.foreach(_.toRdd)
val tablesGenerated = queryList.zip(executions).flatMap {
- case (q, e) => e.sparkPlan.collect {
+ // We should take executedPlan instead of sparkPlan, because the collected plans are
+ // run in the code below. sparkPlan still needs extra processing (adding exchanges,
+ // collapsing codegen stages, etc.), so collecting sparkPlan here would cause errors
+ // when those plans are run later.
+ case (q, e) => e.executedPlan.collect {
case i: InsertIntoHiveTable if tablesRead contains i.table.tableName =>
(q, e, i)
}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala
new file mode 100644
index 0000000000..206d911e0d
--- /dev/null
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala
@@ -0,0 +1,351 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive.execution
+
+import java.io.File
+
+import org.apache.hadoop.fs.Path
+
+import org.apache.spark.sql.{AnalysisException, QueryTest, SaveMode}
+import org.apache.spark.sql.catalyst.catalog.CatalogTableType
+import org.apache.spark.sql.catalyst.TableIdentifier
+import org.apache.spark.sql.hive.test.TestHiveSingleton
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.test.SQLTestUtils
+
+class HiveDDLSuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
+ import hiveContext.implicits._
+
+ // Checks whether the directory that stores the table's data exists.
+ private def tableDirectoryExists(tableIdentifier: TableIdentifier): Boolean = {
+ val expectedTablePath =
+ hiveContext.sessionState.catalog.hiveDefaultTableFilePath(tableIdentifier)
+ val filesystemPath = new Path(expectedTablePath)
+ val fs = filesystemPath.getFileSystem(sparkContext.hadoopConfiguration)
+ fs.exists(filesystemPath)
+ }
+
+ test("drop tables") {
+ withTable("tab1") {
+ val tabName = "tab1"
+
+ assert(!tableDirectoryExists(TableIdentifier(tabName)))
+ sql(s"CREATE TABLE $tabName(c1 int)")
+
+ assert(tableDirectoryExists(TableIdentifier(tabName)))
+ sql(s"DROP TABLE $tabName")
+
+ assert(!tableDirectoryExists(TableIdentifier(tabName)))
+ sql(s"DROP TABLE IF EXISTS $tabName")
+ sql(s"DROP VIEW IF EXISTS $tabName")
+ }
+ }
+
+ test("drop managed tables") {
+ withTempDir { tmpDir =>
+ val tabName = "tab1"
+ withTable(tabName) {
+ assert(tmpDir.listFiles.isEmpty)
+ sql(
+ s"""
+ |create table $tabName
+ |stored as parquet
+ |location '$tmpDir'
+ |as select 1, '3'
+ """.stripMargin)
+
+ val hiveTable =
+ hiveContext.sessionState.catalog
+ .getTableMetadata(TableIdentifier(tabName, Some("default")))
+ // It is a managed table, even though the SQL specifies an external location
+ assert(hiveTable.tableType == CatalogTableType.MANAGED_TABLE)
+
+ assert(tmpDir.listFiles.nonEmpty)
+ sql(s"DROP TABLE $tabName")
+ // The data are deleted since the table type is not EXTERNAL
+ assert(tmpDir.listFiles == null)
+ }
+ }
+ }
+
+ test("drop external data source table") {
+ withTempDir { tmpDir =>
+ val tabName = "tab1"
+ withTable(tabName) {
+ assert(tmpDir.listFiles.isEmpty)
+
+ withSQLConf(SQLConf.PARQUET_WRITE_LEGACY_FORMAT.key -> "true") {
+ Seq(1 -> "a").toDF("i", "j")
+ .write
+ .mode(SaveMode.Overwrite)
+ .format("parquet")
+ .option("path", tmpDir.toString)
+ .saveAsTable(tabName)
+ }
+
+ val hiveTable =
+ hiveContext.sessionState.catalog
+ .getTableMetadata(TableIdentifier(tabName, Some("default")))
+ // This data source table is an external table
+ assert(hiveTable.tableType == CatalogTableType.EXTERNAL_TABLE)
+
+ assert(tmpDir.listFiles.nonEmpty)
+ sql(s"DROP TABLE $tabName")
+ // The data are not deleted since the table type is EXTERNAL
+ assert(tmpDir.listFiles.nonEmpty)
+ }
+ }
+ }
+
+ test("create table and view with comment") {
+ val catalog = hiveContext.sessionState.catalog
+ val tabName = "tab1"
+ withTable(tabName) {
+ sql(s"CREATE TABLE $tabName(c1 int) COMMENT 'BLABLA'")
+ val viewName = "view1"
+ withView(viewName) {
+ sql(s"CREATE VIEW $viewName COMMENT 'no comment' AS SELECT * FROM $tabName")
+ val tableMetadata = catalog.getTableMetadata(TableIdentifier(tabName, Some("default")))
+ val viewMetadata = catalog.getTableMetadata(TableIdentifier(viewName, Some("default")))
+ assert(tableMetadata.properties.get("comment") == Option("BLABLA"))
+ assert(viewMetadata.properties.get("comment") == Option("no comment"))
+ }
+ }
+ }
+
+ test("add/drop partitions - external table") {
+ val catalog = hiveContext.sessionState.catalog
+ withTempDir { tmpDir =>
+ val basePath = tmpDir.getCanonicalPath
+ val partitionPath_1stCol_part1 = new File(basePath + "/ds=2008-04-08")
+ val partitionPath_1stCol_part2 = new File(basePath + "/ds=2008-04-09")
+ val partitionPath_part1 = new File(basePath + "/ds=2008-04-08/hr=11")
+ val partitionPath_part2 = new File(basePath + "/ds=2008-04-09/hr=11")
+ val partitionPath_part3 = new File(basePath + "/ds=2008-04-08/hr=12")
+ val partitionPath_part4 = new File(basePath + "/ds=2008-04-09/hr=12")
+ val dirSet =
+ tmpDir :: partitionPath_1stCol_part1 :: partitionPath_1stCol_part2 ::
+ partitionPath_part1 :: partitionPath_part2 :: partitionPath_part3 ::
+ partitionPath_part4 :: Nil
+
+ val externalTab = "extTable_with_partitions"
+ withTable(externalTab) {
+ assert(tmpDir.listFiles.isEmpty)
+ sql(
+ s"""
+ |CREATE EXTERNAL TABLE $externalTab (key INT, value STRING)
+ |PARTITIONED BY (ds STRING, hr STRING)
+ |LOCATION '$basePath'
+ """.stripMargin)
+
+ // Before data insertion, all the directories are empty
+ assert(dirSet.forall(dir => dir.listFiles == null || dir.listFiles.isEmpty))
+
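+ // Insert one row into each (ds, hr) partition so that every partition directory gets data.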
+ for (ds <- Seq("2008-04-08", "2008-04-09"); hr <- Seq("11", "12")) {
+ sql(
+ s"""
+ |INSERT OVERWRITE TABLE $externalTab
+ |partition (ds='$ds',hr='$hr')
+ |SELECT 1, 'a'
+ """.stripMargin)
+ }
+
+ val hiveTable = catalog.getTableMetadata(TableIdentifier(externalTab, Some("default")))
+ assert(hiveTable.tableType == CatalogTableType.EXTERNAL_TABLE)
+ // After data insertion, all the directories are non-empty
+ assert(dirSet.forall(dir => dir.listFiles.nonEmpty))
+
+ sql(
+ s"""
+ |ALTER TABLE $externalTab DROP PARTITION (ds='2008-04-08'),
+ |PARTITION (ds='2008-04-09', hr='12')
+ """.stripMargin)
+ assert(catalog.listPartitions(TableIdentifier(externalTab)).map(_.spec).toSet ==
+ Set(Map("ds" -> "2008-04-09", "hr" -> "11")))
+ // Dropping a partition does not delete the data of an external table
+ assert(dirSet.forall(dir => dir.listFiles.nonEmpty))
+
+ sql(s"ALTER TABLE $externalTab ADD PARTITION (ds='2008-04-08', hr='12')")
+ assert(catalog.listPartitions(TableIdentifier(externalTab)).map(_.spec).toSet ==
+ Set(Map("ds" -> "2008-04-08", "hr" -> "12"), Map("ds" -> "2008-04-09", "hr" -> "11")))
+ // Adding a partition does not delete the data
+ assert(dirSet.forall(dir => dir.listFiles.nonEmpty))
+
+ sql(s"DROP TABLE $externalTab")
+ // Dropping the table does not delete the data of an external table
+ assert(dirSet.forall(dir => dir.listFiles.nonEmpty))
+ }
+ }
+ }
+
+ test("drop views") {
+ withTable("tab1") {
+ val tabName = "tab1"
+ sqlContext.range(10).write.saveAsTable("tab1")
+ withView("view1") {
+ val viewName = "view1"
+
+ assert(tableDirectoryExists(TableIdentifier(tabName)))
+ assert(!tableDirectoryExists(TableIdentifier(viewName)))
+ sql(s"CREATE VIEW $viewName AS SELECT * FROM tab1")
+
+ assert(tableDirectoryExists(TableIdentifier(tabName)))
+ assert(!tableDirectoryExists(TableIdentifier(viewName)))
+ sql(s"DROP VIEW $viewName")
+
+ assert(tableDirectoryExists(TableIdentifier(tabName)))
+ sql(s"DROP VIEW IF EXISTS $viewName")
+ }
+ }
+ }
+
+ test("alter views - rename") {
+ val tabName = "tab1"
+ withTable(tabName) {
+ sqlContext.range(10).write.saveAsTable(tabName)
+ val oldViewName = "view1"
+ val newViewName = "view2"
+ withView(oldViewName, newViewName) {
+ val catalog = hiveContext.sessionState.catalog
+ sql(s"CREATE VIEW $oldViewName AS SELECT * FROM $tabName")
+
+ assert(catalog.tableExists(TableIdentifier(oldViewName)))
+ assert(!catalog.tableExists(TableIdentifier(newViewName)))
+ sql(s"ALTER VIEW $oldViewName RENAME TO $newViewName")
+ assert(!catalog.tableExists(TableIdentifier(oldViewName)))
+ assert(catalog.tableExists(TableIdentifier(newViewName)))
+ }
+ }
+ }
+
+ test("alter views - set/unset tblproperties") {
+ val tabName = "tab1"
+ withTable(tabName) {
+ sqlContext.range(10).write.saveAsTable(tabName)
+ val viewName = "view1"
+ withView(viewName) {
+ val catalog = hiveContext.sessionState.catalog
+ sql(s"CREATE VIEW $viewName AS SELECT * FROM $tabName")
+
+ assert(catalog.getTableMetadata(TableIdentifier(viewName))
+ .properties.filter(_._1 != "transient_lastDdlTime") == Map())
+ sql(s"ALTER VIEW $viewName SET TBLPROPERTIES ('p' = 'an')")
+ assert(catalog.getTableMetadata(TableIdentifier(viewName))
+ .properties.filter(_._1 != "transient_lastDdlTime") == Map("p" -> "an"))
+
+ // No exception or message is issued if we set the same property again
+ sql(s"ALTER VIEW $viewName SET TBLPROPERTIES ('p' = 'an')")
+ assert(catalog.getTableMetadata(TableIdentifier(viewName))
+ .properties.filter(_._1 != "transient_lastDdlTime") == Map("p" -> "an"))
+
+ // the value will be updated if we set the same key to a different value
+ sql(s"ALTER VIEW $viewName SET TBLPROPERTIES ('p' = 'b')")
+ assert(catalog.getTableMetadata(TableIdentifier(viewName))
+ .properties.filter(_._1 != "transient_lastDdlTime") == Map("p" -> "b"))
+
+ sql(s"ALTER VIEW $viewName UNSET TBLPROPERTIES ('p')")
+ assert(catalog.getTableMetadata(TableIdentifier(viewName))
+ .properties.filter(_._1 != "transient_lastDdlTime") == Map())
+
+ val message = intercept[AnalysisException] {
+ sql(s"ALTER VIEW $viewName UNSET TBLPROPERTIES ('p')")
+ }.getMessage
+ assert(message.contains(
+ "attempted to unset non-existent property 'p' in table '`view1`'"))
+ }
+ }
+ }
+
+ test("alter views and alter table - misuse") {
+ val tabName = "tab1"
+ withTable(tabName) {
+ sqlContext.range(10).write.saveAsTable(tabName)
+ val oldViewName = "view1"
+ val newViewName = "view2"
+ withView(oldViewName, newViewName) {
+ val catalog = hiveContext.sessionState.catalog
+ sql(s"CREATE VIEW $oldViewName AS SELECT * FROM $tabName")
+
+ assert(catalog.tableExists(TableIdentifier(tabName)))
+ assert(catalog.tableExists(TableIdentifier(oldViewName)))
+
+ var message = intercept[AnalysisException] {
+ sql(s"ALTER VIEW $tabName RENAME TO $newViewName")
+ }.getMessage
+ assert(message.contains(
+ "Cannot alter a table with ALTER VIEW. Please use ALTER TABLE instead"))
+
+ message = intercept[AnalysisException] {
+ sql(s"ALTER VIEW $tabName SET TBLPROPERTIES ('p' = 'an')")
+ }.getMessage
+ assert(message.contains(
+ "Cannot alter a table with ALTER VIEW. Please use ALTER TABLE instead"))
+
+ message = intercept[AnalysisException] {
+ sql(s"ALTER VIEW $tabName UNSET TBLPROPERTIES ('p')")
+ }.getMessage
+ assert(message.contains(
+ "Cannot alter a table with ALTER VIEW. Please use ALTER TABLE instead"))
+
+ message = intercept[AnalysisException] {
+ sql(s"ALTER TABLE $oldViewName RENAME TO $newViewName")
+ }.getMessage
+ assert(message.contains(
+ "Cannot alter a view with ALTER TABLE. Please use ALTER VIEW instead"))
+
+ message = intercept[AnalysisException] {
+ sql(s"ALTER TABLE $oldViewName SET TBLPROPERTIES ('p' = 'an')")
+ }.getMessage
+ assert(message.contains(
+ "Cannot alter a view with ALTER TABLE. Please use ALTER VIEW instead"))
+
+ message = intercept[AnalysisException] {
+ sql(s"ALTER TABLE $oldViewName UNSET TBLPROPERTIES ('p')")
+ }.getMessage
+ assert(message.contains(
+ "Cannot alter a view with ALTER TABLE. Please use ALTER VIEW instead"))
+
+ assert(catalog.tableExists(TableIdentifier(tabName)))
+ assert(catalog.tableExists(TableIdentifier(oldViewName)))
+ }
+ }
+ }
+
+ test("drop table using drop view") {
+ withTable("tab1") {
+ sql("CREATE TABLE tab1(c1 int)")
+ val message = intercept[AnalysisException] {
+ sql("DROP VIEW tab1")
+ }.getMessage
+ assert(message.contains("Cannot drop a table with DROP VIEW. Please use DROP TABLE instead"))
+ }
+ }
+
+ test("drop view using drop table") {
+ withTable("tab1") {
+ sqlContext.range(10).write.saveAsTable("tab1")
+ withView("view1") {
+ sql("CREATE VIEW view1 AS SELECT * FROM tab1")
+ val message = intercept[AnalysisException] {
+ sql("DROP TABLE view1")
+ }.getMessage
+ assert(message.contains("Cannot drop a view with DROP TABLE. Please use DROP VIEW instead"))
+ }
+ }
+ }
+}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala
index b7ef5d1db7..c45d49d6c0 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala
@@ -101,4 +101,33 @@ class HiveExplainSuite extends QueryTest with SQLTestUtils with TestHiveSingleto
"Physical Plan should not contain Subquery since it's eliminated by optimizer")
}
}
+
+ test("EXPLAIN CODEGEN command") {
+ checkExistence(sql("EXPLAIN CODEGEN SELECT 1"), true,
+ "WholeStageCodegen",
+ "Generated code:",
+ "/* 001 */ public Object generate(Object[] references) {",
+ "/* 002 */ return new GeneratedIterator(references);",
+ "/* 003 */ }"
+ )
+
+ checkExistence(sql("EXPLAIN CODEGEN SELECT 1"), false,
+ "== Physical Plan =="
+ )
+
+ checkExistence(sql("EXPLAIN EXTENDED CODEGEN SELECT 1"), true,
+ "WholeStageCodegen",
+ "Generated code:",
+ "/* 001 */ public Object generate(Object[] references) {",
+ "/* 002 */ return new GeneratedIterator(references);",
+ "/* 003 */ }"
+ )
+
+ checkExistence(sql("EXPLAIN EXTENDED CODEGEN SELECT 1"), false,
+ "== Parsed Logical Plan ==",
+ "== Analyzed Logical Plan ==",
+ "== Optimized Logical Plan ==",
+ "== Physical Plan =="
+ )
+ }
}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
index 197a123905..af73baa1f3 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
@@ -28,7 +28,6 @@ import org.scalatest.BeforeAndAfter
import org.apache.spark.{SparkException, SparkFiles}
import org.apache.spark.sql.{AnalysisException, DataFrame, Row}
-import org.apache.spark.sql.catalyst.analysis.NoSuchDatabaseException
import org.apache.spark.sql.catalyst.expressions.Cast
import org.apache.spark.sql.catalyst.plans.logical.Project
import org.apache.spark.sql.execution.joins.BroadcastNestedLoopJoin
@@ -62,12 +61,17 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter {
TestHive.cacheTables = false
TimeZone.setDefault(originalTimeZone)
Locale.setDefault(originalLocale)
- sql("DROP TEMPORARY FUNCTION udtf_count2")
+ sql("DROP TEMPORARY FUNCTION IF EXISTS udtf_count2")
} finally {
super.afterAll()
}
}
+ private def assertUnsupportedFeature(body: => Unit): Unit = {
+ val e = intercept[AnalysisException] { body }
+ assert(e.getMessage.toLowerCase.contains("unsupported operation"))
+ }
+
test("SPARK-4908: concurrent hive native commands") {
(1 to 100).par.map { _ =>
sql("USE default")
@@ -702,7 +706,7 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter {
def isExplanation(result: DataFrame): Boolean = {
val explanation = result.select('plan).collect().map { case Row(plan: String) => plan }
- explanation.contains("== Physical Plan ==")
+ explanation.head.startsWith("== Physical Plan ==")
}
test("SPARK-1704: Explain commands as a DataFrame") {
@@ -1225,14 +1229,16 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter {
val e = intercept[AnalysisException] {
range(1).selectExpr("not_a_udf()")
}
- assert(e.getMessage.contains("undefined function not_a_udf"))
+ assert(e.getMessage.contains("Undefined function"))
+ assert(e.getMessage.contains("not_a_udf"))
var success = false
val t = new Thread("test") {
override def run(): Unit = {
val e = intercept[AnalysisException] {
range(1).selectExpr("not_a_udf()")
}
- assert(e.getMessage.contains("undefined function not_a_udf"))
+ assert(e.getMessage.contains("Undefined function"))
+ assert(e.getMessage.contains("not_a_udf"))
success = true
}
}
@@ -1246,6 +1252,57 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter {
// Put tests that depend on specific Hive settings before these last two test,
// since they modify /clear stuff.
+
+ test("role management commands are not supported") {
+ assertUnsupportedFeature { sql("CREATE ROLE my_role") }
+ assertUnsupportedFeature { sql("DROP ROLE my_role") }
+ assertUnsupportedFeature { sql("SHOW CURRENT ROLES") }
+ assertUnsupportedFeature { sql("SHOW ROLES") }
+ assertUnsupportedFeature { sql("SHOW GRANT") }
+ assertUnsupportedFeature { sql("SHOW ROLE GRANT USER my_principal") }
+ assertUnsupportedFeature { sql("SHOW PRINCIPALS my_role") }
+ assertUnsupportedFeature { sql("SET ROLE my_role") }
+ assertUnsupportedFeature { sql("GRANT my_role TO USER my_user") }
+ assertUnsupportedFeature { sql("GRANT ALL ON my_table TO USER my_user") }
+ assertUnsupportedFeature { sql("REVOKE my_role FROM USER my_user") }
+ assertUnsupportedFeature { sql("REVOKE ALL ON my_table FROM USER my_user") }
+ }
+
+ test("import/export commands are not supported") {
+ assertUnsupportedFeature { sql("IMPORT TABLE my_table FROM 'my_path'") }
+ assertUnsupportedFeature { sql("EXPORT TABLE my_table TO 'my_path'") }
+ }
+
+ test("some show commands are not supported") {
+ assertUnsupportedFeature { sql("SHOW CREATE TABLE my_table") }
+ assertUnsupportedFeature { sql("SHOW COMPACTIONS") }
+ assertUnsupportedFeature { sql("SHOW TRANSACTIONS") }
+ assertUnsupportedFeature { sql("SHOW INDEXES ON my_table") }
+ assertUnsupportedFeature { sql("SHOW LOCKS my_table") }
+ }
+
+ test("lock/unlock table and database commands are not supported") {
+ assertUnsupportedFeature { sql("LOCK TABLE my_table SHARED") }
+ assertUnsupportedFeature { sql("UNLOCK TABLE my_table") }
+ assertUnsupportedFeature { sql("LOCK DATABASE my_db SHARED") }
+ assertUnsupportedFeature { sql("UNLOCK DATABASE my_db") }
+ }
+
+ test("create/drop/alter index commands are not supported") {
+ assertUnsupportedFeature {
+ sql("CREATE INDEX my_index ON TABLE my_table(a) as 'COMPACT' WITH DEFERRED REBUILD")}
+ assertUnsupportedFeature { sql("DROP INDEX my_index ON my_table") }
+ assertUnsupportedFeature { sql("ALTER INDEX my_index ON my_table REBUILD")}
+ assertUnsupportedFeature {
+ sql("ALTER INDEX my_index ON my_table set IDXPROPERTIES (\"prop1\"=\"val1_new\")")}
+ }
+
+ test("create/drop macro commands are not supported") {
+ assertUnsupportedFeature {
+ sql("CREATE TEMPORARY MACRO SIGMOID (x DOUBLE) 1.0 / (1.0 + EXP(-x))")
+ }
+ assertUnsupportedFeature { sql("DROP TEMPORARY MACRO SIGMOID") }
+ }
}
// for SPARK-2180 test
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala
index b0e263dff9..d07ac56586 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala
@@ -303,7 +303,7 @@ class HiveUDFSuite extends QueryTest with TestHiveSingleton with SQLTestUtils {
val message = intercept[AnalysisException] {
sql("SELECT testUDFTwoListList() FROM testUDF")
}.getMessage
- assert(message.contains("No handler for Hive udf"))
+ assert(message.contains("No handler for Hive UDF"))
sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFTwoListList")
}
@@ -313,7 +313,7 @@ class HiveUDFSuite extends QueryTest with TestHiveSingleton with SQLTestUtils {
val message = intercept[AnalysisException] {
sql("SELECT testUDFAnd() FROM testUDF")
}.getMessage
- assert(message.contains("No handler for Hive udf"))
+ assert(message.contains("No handler for Hive UDF"))
sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFAnd")
}
@@ -323,7 +323,7 @@ class HiveUDFSuite extends QueryTest with TestHiveSingleton with SQLTestUtils {
val message = intercept[AnalysisException] {
sql("SELECT testUDAFPercentile(a) FROM testUDF GROUP BY b")
}.getMessage
- assert(message.contains("No handler for Hive udf"))
+ assert(message.contains("No handler for Hive UDF"))
sql("DROP TEMPORARY FUNCTION IF EXISTS testUDAFPercentile")
}
@@ -333,7 +333,7 @@ class HiveUDFSuite extends QueryTest with TestHiveSingleton with SQLTestUtils {
val message = intercept[AnalysisException] {
sql("SELECT testUDAFAverage() FROM testUDF GROUP BY b")
}.getMessage
- assert(message.contains("No handler for Hive udf"))
+ assert(message.contains("No handler for Hive UDF"))
sql("DROP TEMPORARY FUNCTION IF EXISTS testUDAFAverage")
}
@@ -343,7 +343,7 @@ class HiveUDFSuite extends QueryTest with TestHiveSingleton with SQLTestUtils {
val message = intercept[AnalysisException] {
sql("SELECT testUDTFExplode() FROM testUDF")
}.getMessage
- assert(message.contains("No handler for Hive udf"))
+ assert(message.contains("No handler for Hive UDF"))
sql("DROP TEMPORARY FUNCTION IF EXISTS testUDTFExplode")
}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala
index 37c01792d9..97cb9d9720 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala
@@ -149,7 +149,7 @@ class PruningSuite extends HiveComparisonTest with BeforeAndAfter {
val (actualScannedColumns, actualPartValues) = plan.collect {
case p @ HiveTableScan(columns, relation, _) =>
val columnNames = columns.map(_.name)
- val partValues = if (relation.table.partitionColumns.nonEmpty) {
+ val partValues = if (relation.table.partitionColumnNames.nonEmpty) {
p.prunePartitions(relation.getHiveQlPartitions()).map(_.getValues)
} else {
Seq.empty
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
index 6199253d34..5ce16be4dc 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
@@ -19,8 +19,6 @@ package org.apache.spark.sql.hive.execution
import java.sql.{Date, Timestamp}
-import scala.collection.JavaConverters._
-
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.analysis.{EliminateSubqueryAliases, FunctionRegistry}
@@ -67,22 +65,43 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
import hiveContext.implicits._
test("UDTF") {
- sql(s"ADD JAR ${hiveContext.getHiveFile("TestUDTF.jar").getCanonicalPath()}")
- // The function source code can be found at:
- // https://cwiki.apache.org/confluence/display/Hive/DeveloperGuide+UDTF
- sql(
- """
- |CREATE TEMPORARY FUNCTION udtf_count2
- |AS 'org.apache.spark.sql.hive.execution.GenericUDTFCount2'
- """.stripMargin)
+ withUserDefinedFunction("udtf_count2" -> true) {
+ sql(s"ADD JAR ${hiveContext.getHiveFile("TestUDTF.jar").getCanonicalPath()}")
+ // The function source code can be found at:
+ // https://cwiki.apache.org/confluence/display/Hive/DeveloperGuide+UDTF
+ sql(
+ """
+ |CREATE TEMPORARY FUNCTION udtf_count2
+ |AS 'org.apache.spark.sql.hive.execution.GenericUDTFCount2'
+ """.stripMargin)
- checkAnswer(
- sql("SELECT key, cc FROM src LATERAL VIEW udtf_count2(value) dd AS cc"),
- Row(97, 500) :: Row(97, 500) :: Nil)
+ checkAnswer(
+ sql("SELECT key, cc FROM src LATERAL VIEW udtf_count2(value) dd AS cc"),
+ Row(97, 500) :: Row(97, 500) :: Nil)
- checkAnswer(
- sql("SELECT udtf_count2(a) FROM (SELECT 1 AS a FROM src LIMIT 3) t"),
- Row(3) :: Row(3) :: Nil)
+ checkAnswer(
+ sql("SELECT udtf_count2(a) FROM (SELECT 1 AS a FROM src LIMIT 3) t"),
+ Row(3) :: Row(3) :: Nil)
+ }
+ }
+
+ test("permanent UDTF") {
+ withUserDefinedFunction("udtf_count_temp" -> false) {
+ sql(
+ s"""
+ |CREATE FUNCTION udtf_count_temp
+ |AS 'org.apache.spark.sql.hive.execution.GenericUDTFCount2'
+ |USING JAR '${hiveContext.getHiveFile("TestUDTF.jar").getCanonicalPath()}'
+ """.stripMargin)
+
+ checkAnswer(
+ sql("SELECT key, cc FROM src LATERAL VIEW udtf_count_temp(value) dd AS cc"),
+ Row(97, 500) :: Row(97, 500) :: Nil)
+
+ checkAnswer(
+ sql("SELECT udtf_count_temp(a) FROM (SELECT 1 AS a FROM src LIMIT 3) t"),
+ Row(3) :: Row(3) :: Nil)
+ }
}
test("SPARK-6835: udtf in lateral view") {
@@ -169,9 +188,7 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
}
test("show functions") {
- val allBuiltinFunctions =
- (FunctionRegistry.builtin.listFunction().toSet[String] ++
- org.apache.hadoop.hive.ql.exec.FunctionRegistry.getFunctionNames.asScala).toList.sorted
+ val allBuiltinFunctions = FunctionRegistry.builtin.listFunction().toSet[String].toList.sorted
// The TestContext is shared by all the test cases, some functions may be registered before
// this, so we check that all the builtin functions are returned.
val allFunctions = sql("SHOW functions").collect().map(r => r(0))
@@ -185,9 +202,13 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
checkAnswer(sql("SHOW functions `abc`.`abs`"), Row("abs"))
checkAnswer(sql("SHOW functions `~`"), Row("~"))
checkAnswer(sql("SHOW functions `a function doens't exist`"), Nil)
- checkAnswer(sql("SHOW functions `weekofyea.*`"), Row("weekofyear"))
+ checkAnswer(sql("SHOW functions `weekofyea*`"), Row("weekofyear"))
    // this will probably fail if we add more functions with the `sha` prefix.
- checkAnswer(sql("SHOW functions `sha.*`"), Row("sha") :: Row("sha1") :: Row("sha2") :: Nil)
+ checkAnswer(sql("SHOW functions `sha*`"), Row("sha") :: Row("sha1") :: Row("sha2") :: Nil)
+ // Test '|' for alternation.
+ checkAnswer(
+ sql("SHOW functions 'sha*|weekofyea*'"),
+ Row("sha") :: Row("sha1") :: Row("sha2") :: Row("weekofyear") :: Nil)
}
test("describe functions") {
@@ -213,8 +234,26 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
checkExistence(sql("describe functioN `~`"), true,
"Function: ~",
- "Class: org.apache.hadoop.hive.ql.udf.UDFOPBitNot",
- "Usage: ~ n - Bitwise not")
+ "Class: org.apache.spark.sql.catalyst.expressions.BitwiseNot",
+ "Usage: ~ b - Bitwise NOT.")
+
+ // Hard coded describe functions
+ checkExistence(sql("describe function `<>`"), true,
+ "Function: <>",
+ "Usage: a <> b - Returns TRUE if a is not equal to b")
+
+ checkExistence(sql("describe function `!=`"), true,
+ "Function: !=",
+ "Usage: a != b - Returns TRUE if a is not equal to b")
+
+ checkExistence(sql("describe function `between`"), true,
+ "Function: between",
+ "Usage: a [NOT] BETWEEN b AND c - evaluate if a is [not] in between b and c")
+
+ checkExistence(sql("describe function `case`"), true,
+ "Function: case",
+ "Usage: CASE a WHEN b THEN c [WHEN d THEN e]* [ELSE f] END - " +
+ "When a = b, returns c; when a = d, return e; else return f")
}
test("SPARK-5371: union with null and sum") {
@@ -321,7 +360,7 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
var message = intercept[AnalysisException] {
sql("CREATE TABLE ctas1 AS SELECT key k, value FROM src ORDER BY k, value")
}.getMessage
- assert(message.contains("ctas1 already exists"))
+ assert(message.contains("already exists"))
checkRelation("ctas1", true)
sql("DROP TABLE ctas1")
@@ -1449,7 +1488,6 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
sql(
"""CREATE VIEW IF NOT EXISTS
|default.testView (c1 COMMENT 'blabla', c2 COMMENT 'blabla')
- |COMMENT 'blabla'
|TBLPROPERTIES ('a' = 'b')
|AS SELECT * FROM jt""".stripMargin)
checkAnswer(sql("SELECT c1, c2 FROM testView ORDER BY c1"), (1 to 9).map(i => Row(i, i)))
@@ -1811,4 +1849,50 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
}
}
}
+
+ test(
+ "SPARK-14488 \"CREATE TEMPORARY TABLE ... USING ... AS SELECT ...\" " +
+ "shouldn't create persisted table"
+ ) {
+ withTempPath { dir =>
+ withTempTable("t1", "t2") {
+ val path = dir.getCanonicalPath
+ val ds = sqlContext.range(10)
+ ds.registerTempTable("t1")
+
+ sql(
+ s"""CREATE TEMPORARY TABLE t2
+ |USING PARQUET
+ |OPTIONS (PATH '$path')
+ |AS SELECT * FROM t1
+ """.stripMargin)
+
+ checkAnswer(
+ sqlContext.tables().select('isTemporary).filter('tableName === "t2"),
+ Row(true)
+ )
+
+ checkAnswer(table("t2"), table("t1"))
+ }
+ }
+ }
+
+ test(
+ "SPARK-14493 \"CREATE TEMPORARY TABLE ... USING ... AS SELECT ...\" " +
+ "shouldn always be used together with PATH data source option"
+ ) {
+ withTempTable("t") {
+ sqlContext.range(10).registerTempTable("t")
+
+ val message = intercept[IllegalArgumentException] {
+ sql(
+ s"""CREATE TEMPORARY TABLE t1
+ |USING PARQUET
+ |AS SELECT * FROM t
+ """.stripMargin)
+ }.getMessage
+
+ assert(message == "'path' is not specified")
+ }
+ }
}
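
A hedged sketch (outside the suite) of the permanent-function registration and SHOW FUNCTIONS patterns tested above; the function name, class, and jar path below are placeholders, and sqlContext is assumed to be an initialized HiveContext:

    // Register a permanent function backed by a jar, mirroring the "permanent UDTF" test.
    sqlContext.sql(
      """CREATE FUNCTION my_udtf
        |AS 'org.apache.spark.sql.hive.execution.GenericUDTFCount2'
        |USING JAR '/path/to/TestUDTF.jar'
      """.stripMargin)

    // SHOW FUNCTIONS now takes simple '*' wildcards and '|' alternation, not Java regexes.
    sqlContext.sql("SHOW FUNCTIONS 'sha*|weekofyea*'").show()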
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala
index 92f424bac7..5ef8194f28 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala
@@ -26,6 +26,8 @@ import org.scalatest.BeforeAndAfterAll
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.TableIdentifier
+import org.apache.spark.sql.execution.datasources.LogicalRelation
+import org.apache.spark.sql.hive.HiveContext
import org.apache.spark.sql.hive.test.TestHive._
import org.apache.spark.sql.hive.test.TestHive.implicits._
import org.apache.spark.sql.internal.SQLConf
@@ -400,4 +402,41 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest {
}
}
}
+
+ test("SPARK-14070 Use ORC data source for SQL queries on ORC tables") {
+ withTempPath { dir =>
+ withSQLConf(SQLConf.ORC_FILTER_PUSHDOWN_ENABLED.key -> "true",
+ HiveContext.CONVERT_METASTORE_ORC.key -> "true") {
+ val path = dir.getCanonicalPath
+
+ withTable("dummy_orc") {
+ withTempTable("single") {
+ sqlContext.sql(
+ s"""CREATE TABLE dummy_orc(key INT, value STRING)
+ |STORED AS ORC
+ |LOCATION '$path'
+ """.stripMargin)
+
+ val singleRowDF = Seq((0, "foo")).toDF("key", "value").coalesce(1)
+ singleRowDF.registerTempTable("single")
+
+ sqlContext.sql(
+ s"""INSERT INTO TABLE dummy_orc
+ |SELECT key, value FROM single
+ """.stripMargin)
+
+ val df = sqlContext.sql("SELECT * FROM dummy_orc WHERE key=0")
+ checkAnswer(df, singleRowDF)
+
+ val queryExecution = df.queryExecution
+ queryExecution.analyzed.collectFirst {
+ case _: LogicalRelation => ()
+ }.getOrElse {
+ fail(s"Expecting the query plan to have LogicalRelation, but got:\n$queryExecution")
+ }
+ }
+ }
+ }
+ }
+ }
}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala
index b6fc61d453..eac65d5720 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala
@@ -311,7 +311,7 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest {
case ExecutedCommand(_: InsertIntoHadoopFsRelation) => // OK
case o => fail("test_insert_parquet should be converted to a " +
s"${classOf[HadoopFsRelation ].getCanonicalName} and " +
- s"${classOf[InsertIntoDataSource].getCanonicalName} is expcted as the SparkPlan. " +
+ s"${classOf[InsertIntoDataSource].getCanonicalName} is expected as the SparkPlan. " +
s"However, found a ${o.toString} ")
}
@@ -341,7 +341,7 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest {
case ExecutedCommand(_: InsertIntoHadoopFsRelation) => // OK
case o => fail("test_insert_parquet should be converted to a " +
s"${classOf[HadoopFsRelation ].getCanonicalName} and " +
- s"${classOf[InsertIntoDataSource].getCanonicalName} is expcted as the SparkPlan." +
+ s"${classOf[InsertIntoDataSource].getCanonicalName} is expected as the SparkPlan." +
s"However, found a ${o.toString} ")
}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala
index ea7e905742..10eeb30242 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala
@@ -668,40 +668,6 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with Tes
df.write.format(dataSourceName).partitionBy("c", "d", "e").saveAsTable("t")
}
}
-
- test("SPARK-9899 Disable customized output committer when speculation is on") {
- val clonedConf = new Configuration(hadoopConfiguration)
- val speculationEnabled =
- sqlContext.sparkContext.conf.getBoolean("spark.speculation", defaultValue = false)
-
- try {
- withTempPath { dir =>
- // Enables task speculation
- sqlContext.sparkContext.conf.set("spark.speculation", "true")
-
- // Uses a customized output committer which always fails
- hadoopConfiguration.set(
- SQLConf.OUTPUT_COMMITTER_CLASS.key,
- classOf[AlwaysFailOutputCommitter].getName)
-
- // Code below shouldn't throw since customized output committer should be disabled.
- val df = sqlContext.range(10).toDF().coalesce(1)
- df.write.format(dataSourceName).save(dir.getCanonicalPath)
- checkAnswer(
- sqlContext
- .read
- .format(dataSourceName)
- .option("dataSchema", df.schema.json)
- .load(dir.getCanonicalPath),
- df)
- }
- } finally {
- // Hadoop 1 doesn't have `Configuration.unset`
- hadoopConfiguration.clear()
- clonedConf.asScala.foreach(entry => hadoopConfiguration.set(entry.getKey, entry.getValue))
- sqlContext.sparkContext.conf.set("spark.speculation", speculationEnabled.toString)
- }
- }
}
// This class is used to test SPARK-8578. We should not use any custom output committer when
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala
index f9f3d97ef3..0395600954 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala
@@ -247,10 +247,10 @@ class CheckpointWriter(
// Delete old checkpoint files
val allCheckpointFiles = Checkpoint.getCheckpointFiles(checkpointDir, Some(fs))
if (allCheckpointFiles.size > 10) {
- allCheckpointFiles.take(allCheckpointFiles.size - 10).foreach(file => {
+ allCheckpointFiles.take(allCheckpointFiles.size - 10).foreach { file =>
logInfo("Deleting " + file)
fs.delete(file, true)
- })
+ }
}
// All done, print success
@@ -334,8 +334,7 @@ object CheckpointReader extends Logging {
ignoreReadError: Boolean = false): Option[Checkpoint] = {
val checkpointPath = new Path(checkpointDir)
- // TODO(rxin): Why is this a def?!
- def fs: FileSystem = checkpointPath.getFileSystem(hadoopConf)
+ val fs = checkpointPath.getFileSystem(hadoopConf)
// Try to find the checkpoint files
val checkpointFiles = Checkpoint.getCheckpointFiles(checkpointDir, Some(fs)).reverse
@@ -346,7 +345,7 @@ object CheckpointReader extends Logging {
// Try to read the checkpoint files in the order
logInfo("Checkpoint files found: " + checkpointFiles.mkString(","))
var readError: Exception = null
- checkpointFiles.foreach(file => {
+ checkpointFiles.foreach { file =>
logInfo("Attempting to load checkpoint from file " + file)
try {
val fis = fs.open(file)
@@ -359,7 +358,7 @@ object CheckpointReader extends Logging {
readError = e
logWarning("Error reading checkpoint from file " + file, e)
}
- })
+ }
// If none of checkpoint files could be read, then throw exception
if (!ignoreReadError) {
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala
index 3a664c4f5c..928739a416 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala
@@ -18,6 +18,7 @@
package org.apache.spark.streaming
import java.io.{InputStream, NotSerializableException}
+import java.util.Properties
import java.util.concurrent.atomic.{AtomicInteger, AtomicReference}
import scala.collection.Map
@@ -25,6 +26,7 @@ import scala.collection.mutable.Queue
import scala.reflect.ClassTag
import scala.util.control.NonFatal
+import org.apache.commons.lang.SerializationUtils
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path
import org.apache.hadoop.io.{BytesWritable, LongWritable, Text}
@@ -43,7 +45,7 @@ import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming.StreamingContextState._
import org.apache.spark.streaming.dstream._
import org.apache.spark.streaming.receiver.Receiver
-import org.apache.spark.streaming.scheduler.{JobScheduler, StreamingListener}
+import org.apache.spark.streaming.scheduler.{ExecutorAllocationManager, JobScheduler, StreamingListener}
import org.apache.spark.streaming.ui.{StreamingJobProgressListener, StreamingTab}
import org.apache.spark.util.{CallSite, ShutdownHookManager, ThreadUtils, Utils}
@@ -106,7 +108,7 @@ class StreamingContext private[streaming] (
* HDFS compatible filesystems
*/
def this(path: String, hadoopConf: Configuration) =
- this(null, CheckpointReader.read(path, new SparkConf(), hadoopConf).get, null)
+ this(null, CheckpointReader.read(path, new SparkConf(), hadoopConf).orNull, null)
/**
* Recreate a StreamingContext from a checkpoint file.
@@ -122,17 +124,14 @@ class StreamingContext private[streaming] (
def this(path: String, sparkContext: SparkContext) = {
this(
sparkContext,
- CheckpointReader.read(path, sparkContext.conf, sparkContext.hadoopConfiguration).get,
+ CheckpointReader.read(path, sparkContext.conf, sparkContext.hadoopConfiguration).orNull,
null)
}
+ require(_sc != null || _cp != null,
+ "Spark Streaming cannot be initialized with both SparkContext and checkpoint as null")
- if (_sc == null && _cp == null) {
- throw new Exception("Spark Streaming cannot be initialized with " +
- "both SparkContext and checkpoint as null")
- }
-
- private[streaming] val isCheckpointPresent = (_cp != null)
+ private[streaming] val isCheckpointPresent: Boolean = _cp != null
private[streaming] val sc: SparkContext = {
if (_sc != null) {
@@ -201,6 +200,10 @@ class StreamingContext private[streaming] (
private val startSite = new AtomicReference[CallSite](null)
+ // Copy of thread-local properties from SparkContext. These properties will be set in all tasks
+ // submitted by this StreamingContext after start.
+ private[streaming] val savedProperties = new AtomicReference[Properties](new Properties)
+
private[streaming] def getStartSite(): CallSite = startSite.get()
private var shutdownHookRef: AnyRef = _
@@ -213,8 +216,8 @@ class StreamingContext private[streaming] (
def sparkContext: SparkContext = sc
/**
- * Set each DStreams in this context to remember RDDs it generated in the last given duration.
- * DStreams remember RDDs only for a limited duration of time and releases them for garbage
+ * Set each DStream in this context to remember RDDs it generated in the last given duration.
+ * DStreams remember RDDs only for a limited duration of time and release them for garbage
* collection. This method allows the developer to specify how long to remember the RDDs (
* if the developer wishes to query old data outside the DStream computation).
* @param duration Minimum duration that each DStream should remember its RDDs
@@ -282,13 +285,14 @@ class StreamingContext private[streaming] (
}
/**
- * Create a input stream from TCP source hostname:port. Data is received using
+ * Creates an input stream from TCP source hostname:port. Data is received using
 * a TCP socket and the received bytes are interpreted as UTF-8 encoded `\n` delimited
* lines.
* @param hostname Hostname to connect to for receiving data
* @param port Port to connect to for receiving data
* @param storageLevel Storage level to use for storing the received objects
* (default: StorageLevel.MEMORY_AND_DISK_SER_2)
+ * @see [[socketStream]]
*/
def socketTextStream(
hostname: String,
@@ -299,7 +303,7 @@ class StreamingContext private[streaming] (
}
/**
- * Create a input stream from TCP source hostname:port. Data is received using
+ * Creates an input stream from TCP source hostname:port. Data is received using
 * a TCP socket and the received bytes are interpreted as objects using the given
* converter.
* @param hostname Hostname to connect to for receiving data
@@ -496,9 +500,10 @@ class StreamingContext private[streaming] (
new TransformedDStream[T](dstreams, sparkContext.clean(transformFunc))
}
- /** Add a [[org.apache.spark.streaming.scheduler.StreamingListener]] object for
- * receiving system events related to streaming.
- */
+ /**
+ * Add a [[org.apache.spark.streaming.scheduler.StreamingListener]] object for
+ * receiving system events related to streaming.
+ */
def addStreamingListener(streamingListener: StreamingListener) {
scheduler.listenerBus.addListener(streamingListener)
}
@@ -528,11 +533,12 @@ class StreamingContext private[streaming] (
}
}
- if (Utils.isDynamicAllocationEnabled(sc.conf)) {
+ if (Utils.isDynamicAllocationEnabled(sc.conf) ||
+ ExecutorAllocationManager.isDynamicAllocationEnabled(conf)) {
logWarning("Dynamic Allocation is enabled for this application. " +
"Enabling Dynamic allocation for Spark Streaming applications can cause data loss if " +
"Write Ahead Log is not enabled for non-replayable sources like Flume. " +
- "See the programming guide for details on how to enable the Write Ahead Log")
+ "See the programming guide for details on how to enable the Write Ahead Log.")
}
}
@@ -573,6 +579,8 @@ class StreamingContext private[streaming] (
sparkContext.setCallSite(startSite.get)
sparkContext.clearJobGroup()
sparkContext.setLocalProperty(SparkContext.SPARK_JOB_INTERRUPT_ON_CANCEL, "false")
+ savedProperties.set(SerializationUtils.clone(
+ sparkContext.localProperties.get()).asInstanceOf[Properties])
scheduler.start()
}
state = StreamingContextState.ACTIVE
@@ -860,7 +868,7 @@ private class StreamingContextPythonHelper {
*/
def tryRecoverFromCheckpoint(checkpointPath: String): Option[StreamingContext] = {
val checkpointOption = CheckpointReader.read(
- checkpointPath, new SparkConf(), SparkHadoopUtil.get.conf, false)
+ checkpointPath, new SparkConf(), SparkHadoopUtil.get.conf, ignoreReadError = false)
checkpointOption.map(new StreamingContext(null, _, null))
}
}
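
A hedged configuration sketch (not part of the patch) of the interplay the warning above describes: streaming dynamic allocation plus the write-ahead log, with core dynamic allocation left off. The streaming key comes from this patch; spark.streaming.receiver.writeAheadLog.enable is the standard WAL switch and should be verified against the Spark version in use:

    import org.apache.spark.SparkConf

    val conf = new SparkConf()
      .setAppName("streaming-dynamic-allocation-sketch")
      .set("spark.streaming.dynamicAllocation.enabled", "true")     // key introduced in this patch
      .set("spark.streaming.receiver.writeAheadLog.enable", "true") // WAL for non-replayable sources
      .set("spark.dynamicAllocation.enabled", "false")              // core and streaming modes are mutually exclusive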
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala
index 05f4da6fac..922e4a5e4d 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala
@@ -517,9 +517,10 @@ class JavaStreamingContext(val ssc: StreamingContext) extends Closeable {
ssc.remember(duration)
}
- /** Add a [[org.apache.spark.streaming.scheduler.StreamingListener]] object for
- * receiving system events related to streaming.
- */
+ /**
+ * Add a [[org.apache.spark.streaming.scheduler.StreamingListener]] object for
+ * receiving system events related to streaming.
+ */
def addStreamingListener(streamingListener: StreamingListener) {
ssc.addStreamingListener(streamingListener)
}
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ConstantInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ConstantInputDStream.scala
index b5f86fe779..995470ec8d 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ConstantInputDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ConstantInputDStream.scala
@@ -23,7 +23,7 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.streaming.{StreamingContext, Time}
/**
- * An input stream that always returns the same RDD on each timestep. Useful for testing.
+ * An input stream that always returns the same RDD on each time step. Useful for testing.
*/
class ConstantInputDStream[T: ClassTag](_ssc: StreamingContext, rdd: RDD[T])
extends InputDStream[T](_ssc) {
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala
index eb7b64eaf4..58842f9c2f 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala
@@ -83,7 +83,7 @@ abstract class DStream[T: ClassTag] (
// RDDs generated, marked as private[streaming] so that testsuites can access it
@transient
- private[streaming] var generatedRDDs = new HashMap[Time, RDD[T]] ()
+ private[streaming] var generatedRDDs = new HashMap[Time, RDD[T]]()
// Time zero for the DStream
private[streaming] var zeroTime: Time = null
@@ -269,7 +269,7 @@ abstract class DStream[T: ClassTag] (
checkpointDuration == null || rememberDuration > checkpointDuration,
s"The remember duration for ${this.getClass.getSimpleName} has been set to " +
s" $rememberDuration which is not more than the checkpoint interval" +
- s" ($checkpointDuration). Please set it to higher than $checkpointDuration."
+ s" ($checkpointDuration). Please set it to a value higher than $checkpointDuration."
)
dependencies.foreach(_.validateAtStart())
@@ -277,7 +277,7 @@ abstract class DStream[T: ClassTag] (
logInfo(s"Slide time = $slideDuration")
logInfo(s"Storage level = ${storageLevel.description}")
logInfo(s"Checkpoint interval = $checkpointDuration")
- logInfo(s"Remember duration = $rememberDuration")
+ logInfo(s"Remember interval = $rememberDuration")
logInfo(s"Initialized and validated $this")
}
@@ -429,13 +429,12 @@ abstract class DStream[T: ClassTag] (
*/
private[streaming] def generateJob(time: Time): Option[Job] = {
getOrCompute(time) match {
- case Some(rdd) => {
+ case Some(rdd) =>
val jobFunc = () => {
val emptyFunc = { (iterator: Iterator[T]) => {} }
context.sparkContext.runJob(rdd, emptyFunc)
}
Some(new Job(time, jobFunc))
- }
case None => None
}
}
@@ -535,7 +534,7 @@ abstract class DStream[T: ClassTag] (
private def readObject(ois: ObjectInputStream): Unit = Utils.tryOrIOException {
logDebug(s"${this.getClass().getSimpleName}.readObject used")
ois.defaultReadObject()
- generatedRDDs = new HashMap[Time, RDD[T]] ()
+ generatedRDDs = new HashMap[Time, RDD[T]]()
}
// =======================================================================
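
A brief hedged sketch of the remember/checkpoint relationship that the validation message above enforces; `stream` (a DStream) and `ssc` (a StreamingContext) are assumed names and the durations are arbitrary:

    import org.apache.spark.streaming.Seconds

    stream.checkpoint(Seconds(10))   // checkpoint interval for this DStream
    ssc.remember(Seconds(30))        // remember duration, deliberately larger than the checkpoint interval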
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStreamCheckpointData.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStreamCheckpointData.scala
index 365a6bc417..e73837eb96 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStreamCheckpointData.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStreamCheckpointData.scala
@@ -29,7 +29,7 @@ import org.apache.spark.streaming.Time
import org.apache.spark.util.Utils
private[streaming]
-class DStreamCheckpointData[T: ClassTag] (dstream: DStream[T])
+class DStreamCheckpointData[T: ClassTag](dstream: DStream[T])
extends Serializable with Logging {
protected val data = new HashMap[Time, AnyRef]()
@@ -45,7 +45,7 @@ class DStreamCheckpointData[T: ClassTag] (dstream: DStream[T])
/**
* Updates the checkpoint data of the DStream. This gets called every time
* the graph checkpoint is initiated. Default implementation records the
- * checkpoint files to which the generate RDDs of the DStream has been saved.
+ * checkpoint files at which the generated RDDs of the DStream have been saved.
*/
def update(time: Time) {
@@ -103,16 +103,15 @@ class DStreamCheckpointData[T: ClassTag] (dstream: DStream[T])
/**
* Restore the checkpoint data. This gets called once when the DStream graph
- * (along with its DStreams) are being restored from a graph checkpoint file.
+ * (along with its output DStreams) is being restored from a graph checkpoint file.
* Default implementation restores the RDDs from their checkpoint files.
*/
def restore() {
// Create RDDs from the checkpoint data
currentCheckpointFiles.foreach {
- case(time, file) => {
+ case(time, file) =>
logInfo("Restoring checkpointed RDD for time " + time + " from file '" + file + "'")
dstream.generatedRDDs += ((time, dstream.context.sparkContext.checkpointFile[T](file)))
- }
}
}
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala
index 7fba2e8ec0..36f50e04db 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala
@@ -333,14 +333,13 @@ class FileInputDStream[K, V, F <: NewInputFormat[K, V]](
override def restore() {
hadoopFiles.toSeq.sortBy(_._1)(Time.ordering).foreach {
- case (t, f) => {
+ case (t, f) =>
// Restore the metadata in both files and generatedRDDs
logInfo("Restoring files for time " + t + " - " +
f.mkString("[", ", ", "]") )
batchTimeToSelectedFiles.synchronized { batchTimeToSelectedFiles += ((t, f)) }
recentlySelectedFiles ++= f
generatedRDDs += ((t, filesToRDD(f)))
- }
}
}
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala
index 0b6b191dbe..a3c125c306 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala
@@ -39,7 +39,7 @@ import org.apache.spark.util.Utils
*
* @param _ssc Streaming context that will execute this input stream
*/
-abstract class InputDStream[T: ClassTag] (_ssc: StreamingContext)
+abstract class InputDStream[T: ClassTag](_ssc: StreamingContext)
extends DStream[T](_ssc) {
private[streaming] var lastValidTime: Time = null
@@ -90,8 +90,8 @@ abstract class InputDStream[T: ClassTag] (_ssc: StreamingContext)
} else {
 // Time is valid, but check if it is more than lastValidTime
if (lastValidTime != null && time < lastValidTime) {
- logWarning("isTimeValid called with " + time + " where as last valid time is " +
- lastValidTime)
+ logWarning(s"isTimeValid called with $time whereas the last valid time " +
+ s"is $lastValidTime")
}
lastValidTime = time
true
@@ -107,8 +107,8 @@ abstract class InputDStream[T: ClassTag] (_ssc: StreamingContext)
}
/** Method called to start receiving data. Subclasses must implement this method. */
- def start()
+ def start(): Unit
/** Method called to stop receiving data. Subclasses must implement this method. */
- def stop()
+ def stop(): Unit
}
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReducedWindowedDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReducedWindowedDStream.scala
index a9be2f213f..a9e93838b8 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReducedWindowedDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReducedWindowedDStream.scala
@@ -87,7 +87,7 @@ class ReducedWindowedDStream[K: ClassTag, V: ClassTag](
logDebug("Window time = " + windowDuration)
logDebug("Slide time = " + slideDuration)
- logDebug("ZeroTime = " + zeroTime)
+ logDebug("Zero time = " + zeroTime)
logDebug("Current window = " + currentWindow)
logDebug("Previous window = " + previousWindow)
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/StateDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/StateDStream.scala
index 68eff89030..8efb09a8ce 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/StateDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/StateDStream.scala
@@ -48,11 +48,11 @@ class StateDStream[K: ClassTag, V: ClassTag, S: ClassTag](
// and then apply the update function
val updateFuncLocal = updateFunc
val finalFunc = (iterator: Iterator[(K, (Iterable[V], Iterable[S]))]) => {
- val i = iterator.map(t => {
+ val i = iterator.map { t =>
val itr = t._2._2.iterator
val headOption = if (itr.hasNext) Some(itr.next()) else None
(t._1, t._2._1.toSeq, headOption)
- })
+ }
updateFuncLocal(i)
}
val cogroupedRDD = parentRDD.cogroup(prevStateRDD, partitioner)
@@ -65,14 +65,12 @@ class StateDStream[K: ClassTag, V: ClassTag, S: ClassTag](
// Try to get the previous state RDD
getOrCompute(validTime - slideDuration) match {
- case Some(prevStateRDD) => { // If previous state RDD exists
-
+ case Some(prevStateRDD) => // If previous state RDD exists
// Try to get the parent RDD
parent.getOrCompute(validTime) match {
- case Some(parentRDD) => { // If parent RDD exists, then compute as usual
- computeUsingPreviousRDD (parentRDD, prevStateRDD)
- }
- case None => { // If parent RDD does not exist
+ case Some(parentRDD) => // If parent RDD exists, then compute as usual
+ computeUsingPreviousRDD(parentRDD, prevStateRDD)
+ case None => // If parent RDD does not exist
// Re-apply the update function to the old state RDD
val updateFuncLocal = updateFunc
@@ -82,41 +80,33 @@ class StateDStream[K: ClassTag, V: ClassTag, S: ClassTag](
}
val stateRDD = prevStateRDD.mapPartitions(finalFunc, preservePartitioning)
Some(stateRDD)
- }
}
- }
-
- case None => { // If previous session RDD does not exist (first input data)
+ case None => // If previous session RDD does not exist (first input data)
// Try to get the parent RDD
parent.getOrCompute(validTime) match {
- case Some(parentRDD) => { // If parent RDD exists, then compute as usual
+ case Some(parentRDD) => // If parent RDD exists, then compute as usual
initialRDD match {
- case None => {
+ case None =>
// Define the function for the mapPartition operation on grouped RDD;
// first map the grouped tuple to tuples of required type,
// and then apply the update function
val updateFuncLocal = updateFunc
val finalFunc = (iterator: Iterator[(K, Iterable[V])]) => {
- updateFuncLocal (iterator.map (tuple => (tuple._1, tuple._2.toSeq, None)))
+ updateFuncLocal(iterator.map(tuple => (tuple._1, tuple._2.toSeq, None)))
}
- val groupedRDD = parentRDD.groupByKey (partitioner)
- val sessionRDD = groupedRDD.mapPartitions (finalFunc, preservePartitioning)
+ val groupedRDD = parentRDD.groupByKey(partitioner)
+ val sessionRDD = groupedRDD.mapPartitions(finalFunc, preservePartitioning)
// logDebug("Generating state RDD for time " + validTime + " (first)")
- Some (sessionRDD)
- }
- case Some (initialStateRDD) => {
+ Some(sessionRDD)
+ case Some(initialStateRDD) =>
computeUsingPreviousRDD(parentRDD, initialStateRDD)
- }
}
- }
- case None => { // If parent RDD does not exist, then nothing to do!
+ case None => // If parent RDD does not exist, then nothing to do!
// logDebug("Not generating state RDD (no previous state, no parent)")
None
- }
}
- }
}
}
}
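
StateDStream backs the updateStateByKey transformation; a hedged usage sketch of a running count follows, assuming `ssc` is a StreamingContext with a checkpoint directory set and `words` is a DStream[String] (both names are illustrative):

    // Hedged sketch: running word count kept as per-key state.
    val counts = words.map(w => (w, 1))
      .updateStateByKey[Int] { (newValues: Seq[Int], state: Option[Int]) =>
        // Merge the new values of this batch into the previous state, if any.
        Some(state.getOrElse(0) + newValues.sum)
      }
    counts.print()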
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/UnionDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/UnionDStream.scala
index c1846a31f6..d46c0a01e0 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/UnionDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/UnionDStream.scala
@@ -21,7 +21,7 @@ import scala.collection.mutable.ArrayBuffer
import scala.reflect.ClassTag
import org.apache.spark.SparkException
-import org.apache.spark.rdd.{RDD, UnionRDD}
+import org.apache.spark.rdd.RDD
import org.apache.spark.streaming.{Duration, Time}
private[streaming]
@@ -45,7 +45,7 @@ class UnionDStream[T: ClassTag](parents: Array[DStream[T]])
s" time $validTime")
}
if (rdds.nonEmpty) {
- Some(new UnionRDD(ssc.sc, rdds))
+ Some(ssc.sc.union(rdds))
} else {
None
}
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/WindowedDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/WindowedDStream.scala
index ee50a8d024..fe0f875525 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/WindowedDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/WindowedDStream.scala
@@ -19,7 +19,7 @@ package org.apache.spark.streaming.dstream
import scala.reflect.ClassTag
-import org.apache.spark.rdd.{PartitionerAwareUnionRDD, RDD, UnionRDD}
+import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming._
import org.apache.spark.streaming.Duration
@@ -63,13 +63,6 @@ class WindowedDStream[T: ClassTag](
override def compute(validTime: Time): Option[RDD[T]] = {
val currentWindow = new Interval(validTime - windowDuration + parent.slideDuration, validTime)
val rddsInWindow = parent.slice(currentWindow)
- val windowRDD = if (rddsInWindow.flatMap(_.partitioner).distinct.length == 1) {
- logDebug("Using partition aware union for windowing at " + validTime)
- new PartitionerAwareUnionRDD(ssc.sc, rddsInWindow)
- } else {
- logDebug("Using normal union for windowing at " + validTime)
- new UnionRDD(ssc.sc, rddsInWindow)
- }
- Some(windowRDD)
+ Some(ssc.sc.union(rddsInWindow))
}
}
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala b/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala
index c56520b1e2..53fccd8d5e 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala
@@ -162,7 +162,8 @@ class WriteAheadLogBackedBlockRDD[T: ClassTag](
logDebug(s"Stored partition data of $this into block manager with level $storageLevel")
dataRead.rewind()
}
- serializerManager.dataDeserialize(blockId, new ChunkedByteBuffer(dataRead))
+ serializerManager
+ .dataDeserializeStream(blockId, new ChunkedByteBuffer(dataRead).toInputStream())
.asInstanceOf[Iterator[T]]
}
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala
index e42bea6ec6..4592e015ed 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala
@@ -37,7 +37,7 @@ private[streaming] trait BlockGeneratorListener {
* that will be useful when a block is generated. Any long blocking operation in this callback
* will hurt the throughput.
*/
- def onAddData(data: Any, metadata: Any)
+ def onAddData(data: Any, metadata: Any): Unit
/**
* Called when a new block of data is generated by the block generator. The block generation
@@ -47,7 +47,7 @@ private[streaming] trait BlockGeneratorListener {
* be useful when the block has been successfully stored. Any long blocking operation in this
* callback will hurt the throughput.
*/
- def onGenerateBlock(blockId: StreamBlockId)
+ def onGenerateBlock(blockId: StreamBlockId): Unit
/**
* Called when a new block is ready to be pushed. Callers are supposed to store the block into
@@ -55,13 +55,13 @@ private[streaming] trait BlockGeneratorListener {
* thread, that is not synchronized with any other callbacks. Hence it is okay to do long
* blocking operation in this callback.
*/
- def onPushBlock(blockId: StreamBlockId, arrayBuffer: ArrayBuffer[_])
+ def onPushBlock(blockId: StreamBlockId, arrayBuffer: ArrayBuffer[_]): Unit
/**
 * Called when an error has occurred in the BlockGenerator. Can be called from many places,
 * so it is better not to do any long blocking operation in this callback.
*/
- def onError(message: String, throwable: Throwable)
+ def onError(message: String, throwable: Throwable): Unit
}
/**
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/RateLimiter.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/RateLimiter.scala
index b2189103a0..fbac4880bd 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/RateLimiter.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/RateLimiter.scala
@@ -22,17 +22,18 @@ import com.google.common.util.concurrent.{RateLimiter => GuavaRateLimiter}
import org.apache.spark.SparkConf
import org.apache.spark.internal.Logging
-/** Provides waitToPush() method to limit the rate at which receivers consume data.
- *
- * waitToPush method will block the thread if too many messages have been pushed too quickly,
- * and only return when a new message has been pushed. It assumes that only one message is
- * pushed at a time.
- *
- * The spark configuration spark.streaming.receiver.maxRate gives the maximum number of messages
- * per second that each receiver will accept.
- *
- * @param conf spark configuration
- */
+/**
+ * Provides waitToPush() method to limit the rate at which receivers consume data.
+ *
+ * waitToPush method will block the thread if too many messages have been pushed too quickly,
+ * and only return when a new message has been pushed. It assumes that only one message is
+ * pushed at a time.
+ *
+ * The spark configuration spark.streaming.receiver.maxRate gives the maximum number of messages
+ * per second that each receiver will accept.
+ *
+ * @param conf spark configuration
+ */
private[receiver] abstract class RateLimiter(conf: SparkConf) extends Logging {
// treated as an upper limit
@@ -52,7 +53,7 @@ private[receiver] abstract class RateLimiter(conf: SparkConf) extends Logging {
* Set the rate limit to `newRate`. The new rate will not exceed the maximum rate configured by
* {{{spark.streaming.receiver.maxRate}}}, even if `newRate` is higher than that.
*
- * @param newRate A new rate in events per second. It has no effect if it's 0 or negative.
+ * @param newRate A new rate in records per second. It has no effect if it's 0 or negative.
*/
private[receiver] def updateRate(newRate: Long): Unit =
if (newRate > 0) {
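
A hedged sketch of the receiver rate cap described in the doc comment above; both keys are standard Spark Streaming settings:

    import org.apache.spark.SparkConf

    val conf = new SparkConf()
      .set("spark.streaming.receiver.maxRate", "1000")      // upper bound, records per second per receiver
      .set("spark.streaming.backpressure.enabled", "true")  // let the effective rate adapt up to maxRate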
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala
index 85350ff658..7aea1c9b64 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala
@@ -48,7 +48,7 @@ private[streaming] trait ReceivedBlockHandler {
def storeBlock(blockId: StreamBlockId, receivedBlock: ReceivedBlock): ReceivedBlockStoreResult
/** Cleanup old blocks older than the given threshold time */
- def cleanupOldBlocks(threshTime: Long)
+ def cleanupOldBlocks(threshTime: Long): Unit
}
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/Receiver.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/Receiver.scala
index 3376cd557d..5157ca62dc 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/Receiver.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/Receiver.scala
@@ -99,13 +99,13 @@ abstract class Receiver[T](val storageLevel: StorageLevel) extends Serializable
* (iii) `restart(...)` can be called to restart the receiver. This will call `onStop()`
* immediately, and then `onStart()` after a delay.
*/
- def onStart()
+ def onStart(): Unit
/**
* This method is called by the system when the receiver is stopped. All resources
* (threads, buffers, etc.) set up in `onStart()` must be cleaned up in this method.
*/
- def onStop()
+ def onStop(): Unit
/** Override this to specify a preferred location (hostname). */
def preferredLocation: Option[String] = None
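
A hedged sketch (not part of the patch) of a custom receiver honouring the onStart()/onStop() contract described above; the socket source, host, and port are placeholders:

    import java.io.{BufferedReader, InputStreamReader}
    import java.net.Socket
    import java.nio.charset.StandardCharsets

    import org.apache.spark.storage.StorageLevel
    import org.apache.spark.streaming.receiver.Receiver

    class LineReceiver(host: String, port: Int)
      extends Receiver[String](StorageLevel.MEMORY_AND_DISK_2) {

      def onStart(): Unit = {
        // Receiving must run on its own thread; onStart() has to return promptly.
        new Thread("LineReceiver") {
          override def run(): Unit = receive()
        }.start()
      }

      def onStop(): Unit = {
        // Nothing to do: the receiving thread checks isStopped() and exits on its own.
      }

      private def receive(): Unit = {
        var socket: Socket = null
        try {
          socket = new Socket(host, port)
          val reader = new BufferedReader(
            new InputStreamReader(socket.getInputStream, StandardCharsets.UTF_8))
          var line = reader.readLine()
          while (!isStopped() && line != null) {
            store(line)                 // hand each record to Spark
            line = reader.readLine()
          }
          restart("Trying to connect again")
        } catch {
          case e: java.io.IOException => restart("Error receiving data", e)
        } finally {
          if (socket != null) socket.close()
        }
      }
    }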
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisor.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisor.scala
index e0fe8d2206..42fc84c19b 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisor.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisor.scala
@@ -70,28 +70,28 @@ private[streaming] abstract class ReceiverSupervisor(
@volatile private[streaming] var receiverState = Initialized
/** Push a single data item to backend data store. */
- def pushSingle(data: Any)
+ def pushSingle(data: Any): Unit
/** Store the bytes of received data as a data block into Spark's memory. */
def pushBytes(
bytes: ByteBuffer,
optionalMetadata: Option[Any],
optionalBlockId: Option[StreamBlockId]
- )
+ ): Unit
 /** Store an iterator of received data as a data block into Spark's memory. */
def pushIterator(
iterator: Iterator[_],
optionalMetadata: Option[Any],
optionalBlockId: Option[StreamBlockId]
- )
+ ): Unit
/** Store an ArrayBuffer of received data as a data block into Spark's memory. */
def pushArrayBuffer(
arrayBuffer: ArrayBuffer[_],
optionalMetadata: Option[Any],
optionalBlockId: Option[StreamBlockId]
- )
+ ): Unit
/**
* Create a custom [[BlockGenerator]] that the receiver implementation can directly control
@@ -103,7 +103,7 @@ private[streaming] abstract class ReceiverSupervisor(
def createBlockGenerator(blockGeneratorListener: BlockGeneratorListener): BlockGenerator
/** Report errors. */
- def reportError(message: String, throwable: Throwable)
+ def reportError(message: String, throwable: Throwable): Unit
/**
* Called when supervisor is started.
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ExecutorAllocationManager.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ExecutorAllocationManager.scala
new file mode 100644
index 0000000000..f7b6584893
--- /dev/null
+++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ExecutorAllocationManager.scala
@@ -0,0 +1,233 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+package org.apache.spark.streaming.scheduler
+
+import scala.util.Random
+
+import org.apache.spark.{ExecutorAllocationClient, SparkConf}
+import org.apache.spark.internal.Logging
+import org.apache.spark.streaming.util.RecurringTimer
+import org.apache.spark.util.{Clock, Utils}
+
+/**
+ * Class that manages executors allocated to a StreamingContext, and dynamically requests or kills
+ * executors based on the statistics of the streaming computation. This is different from the core
+ * dynamic allocation policy; the core policy relies on executors being idle for a while, but the
+ * micro-batch model of streaming prevents any particular executor from being idle for a long
+ * time. Instead, the measure of "idle-ness" needs to be based on the time taken to process
+ * each batch.
+ *
+ * At a high level, the policy implemented by this class is as follows:
+ * - Use the StreamingListener interface to get batch processing times of completed batches
+ * - Periodically take the average batch completion times and compare with the batch interval
+ * - If (avg. proc. time / batch interval) >= scaling up ratio, then request more executors.
+ * The number of executors requested is based on the ratio = (avg. proc. time / batch interval).
+ * - If (avg. proc. time / batch interval) <= scaling down ratio, then try to kill an executor that
+ * is not running a receiver.
+ *
+ * This feature should ideally be used in conjunction with backpressure, as backpressure ensures
+ * system stability, while executors are being readjusted.
+ */
+private[streaming] class ExecutorAllocationManager(
+ client: ExecutorAllocationClient,
+ receiverTracker: ReceiverTracker,
+ conf: SparkConf,
+ batchDurationMs: Long,
+ clock: Clock) extends StreamingListener with Logging {
+
+ import ExecutorAllocationManager._
+
+ private val scalingIntervalSecs = conf.getTimeAsSeconds(
+ SCALING_INTERVAL_KEY,
+ s"${SCALING_INTERVAL_DEFAULT_SECS}s")
+ private val scalingUpRatio = conf.getDouble(SCALING_UP_RATIO_KEY, SCALING_UP_RATIO_DEFAULT)
+ private val scalingDownRatio = conf.getDouble(SCALING_DOWN_RATIO_KEY, SCALING_DOWN_RATIO_DEFAULT)
+ private val minNumExecutors = conf.getInt(
+ MIN_EXECUTORS_KEY,
+ math.max(1, receiverTracker.numReceivers))
+ private val maxNumExecutors = conf.getInt(MAX_EXECUTORS_KEY, Integer.MAX_VALUE)
+ private val timer = new RecurringTimer(clock, scalingIntervalSecs * 1000,
+ _ => manageAllocation(), "streaming-executor-allocation-manager")
+
+ @volatile private var batchProcTimeSum = 0L
+ @volatile private var batchProcTimeCount = 0
+
+ validateSettings()
+
+ def start(): Unit = {
+ timer.start()
+ logInfo(s"ExecutorAllocationManager started with " +
+ s"ratios = [$scalingUpRatio, $scalingDownRatio] and interval = $scalingIntervalSecs sec")
+ }
+
+ def stop(): Unit = {
+ timer.stop(interruptTimer = true)
+ logInfo("ExecutorAllocationManager stopped")
+ }
+
+ /**
+ * Manage executor allocation by requesting or killing executors based on the collected
+ * batch statistics.
+ */
+ private def manageAllocation(): Unit = synchronized {
+ logInfo(s"Managing executor allocation with ratios = [$scalingUpRatio, $scalingDownRatio]")
+ if (batchProcTimeCount > 0) {
+ val averageBatchProcTime = batchProcTimeSum / batchProcTimeCount
+ val ratio = averageBatchProcTime.toDouble / batchDurationMs
+ logInfo(s"Average: $averageBatchProcTime, ratio = $ratio" )
+ if (ratio >= scalingUpRatio) {
+ logDebug("Requesting executors")
+ val numNewExecutors = math.max(math.round(ratio).toInt, 1)
+ requestExecutors(numNewExecutors)
+ } else if (ratio <= scalingDownRatio) {
+ logDebug("Killing executors")
+ killExecutor()
+ }
+ }
+ batchProcTimeSum = 0
+ batchProcTimeCount = 0
+ }
+
+ /** Request the specified number of executors over the currently active ones */
+ private def requestExecutors(numNewExecutors: Int): Unit = {
+ require(numNewExecutors >= 1)
+ val allExecIds = client.getExecutorIds()
+ logDebug(s"Executors (${allExecIds.size}) = ${allExecIds}")
+ val targetTotalExecutors =
+ math.max(math.min(maxNumExecutors, allExecIds.size + numNewExecutors), minNumExecutors)
+ client.requestTotalExecutors(targetTotalExecutors, 0, Map.empty)
+ logInfo(s"Requested total $targetTotalExecutors executors")
+ }
+
+ /** Kill an executor that is not running any receiver, if possible */
+ private def killExecutor(): Unit = {
+ val allExecIds = client.getExecutorIds()
+ logDebug(s"Executors (${allExecIds.size}) = ${allExecIds}")
+
+ if (allExecIds.nonEmpty && allExecIds.size > minNumExecutors) {
+ val execIdsWithReceivers = receiverTracker.allocatedExecutors.values.flatten.toSeq
+ logInfo(s"Executors with receivers (${execIdsWithReceivers.size}): ${execIdsWithReceivers}")
+
+ val removableExecIds = allExecIds.diff(execIdsWithReceivers)
+ logDebug(s"Removable executors (${removableExecIds.size}): ${removableExecIds}")
+ if (removableExecIds.nonEmpty) {
+ val execIdToRemove = removableExecIds(Random.nextInt(removableExecIds.size))
+ client.killExecutor(execIdToRemove)
+ logInfo(s"Requested to kill executor $execIdToRemove")
+ } else {
+ logInfo(s"No non-receiver executors to kill")
+ }
+ } else {
+ logInfo("No available executor to kill")
+ }
+ }
+
+ private def addBatchProcTime(timeMs: Long): Unit = synchronized {
+ batchProcTimeSum += timeMs
+ batchProcTimeCount += 1
+ logDebug(
+ s"Added batch processing time $timeMs, sum = $batchProcTimeSum, count = $batchProcTimeCount")
+ }
+
+ private def validateSettings(): Unit = {
+ require(
+ scalingIntervalSecs > 0,
+ s"Config $SCALING_INTERVAL_KEY must be more than 0")
+
+ require(
+ scalingUpRatio > 0,
+ s"Config $SCALING_UP_RATIO_KEY must be more than 0")
+
+ require(
+ scalingDownRatio > 0,
+ s"Config $SCALING_DOWN_RATIO_KEY must be more than 0")
+
+ require(
+ minNumExecutors > 0,
+ s"Config $MIN_EXECUTORS_KEY must be more than 0")
+
+ require(
+ maxNumExecutors > 0,
+ s"$MAX_EXECUTORS_KEY must be more than 0")
+
+ require(
+ scalingUpRatio > scalingDownRatio,
+ s"Config $SCALING_UP_RATIO_KEY must be more than config $SCALING_DOWN_RATIO_KEY")
+
+ if (conf.contains(MIN_EXECUTORS_KEY) && conf.contains(MAX_EXECUTORS_KEY)) {
+ require(
+ maxNumExecutors >= minNumExecutors,
+ s"Config $MAX_EXECUTORS_KEY must be more than config $MIN_EXECUTORS_KEY")
+ }
+ }
+
+ override def onBatchCompleted(batchCompleted: StreamingListenerBatchCompleted): Unit = {
+ logDebug("onBatchCompleted called: " + batchCompleted)
+ if (!batchCompleted.batchInfo.outputOperationInfos.values.exists(_.failureReason.nonEmpty)) {
+ batchCompleted.batchInfo.processingDelay.foreach(addBatchProcTime)
+ }
+ }
+}
+
+private[streaming] object ExecutorAllocationManager extends Logging {
+ val ENABLED_KEY = "spark.streaming.dynamicAllocation.enabled"
+
+ val SCALING_INTERVAL_KEY = "spark.streaming.dynamicAllocation.scalingInterval"
+ val SCALING_INTERVAL_DEFAULT_SECS = 60
+
+ val SCALING_UP_RATIO_KEY = "spark.streaming.dynamicAllocation.scalingUpRatio"
+ val SCALING_UP_RATIO_DEFAULT = 0.9
+
+ val SCALING_DOWN_RATIO_KEY = "spark.streaming.dynamicAllocation.scalingDownRatio"
+ val SCALING_DOWN_RATIO_DEFAULT = 0.3
+
+ val MIN_EXECUTORS_KEY = "spark.streaming.dynamicAllocation.minExecutors"
+
+ val MAX_EXECUTORS_KEY = "spark.streaming.dynamicAllocation.maxExecutors"
+
+ def isDynamicAllocationEnabled(conf: SparkConf): Boolean = {
+ val numExecutor = conf.getInt("spark.executor.instances", 0)
+ val streamingDynamicAllocationEnabled = conf.getBoolean(ENABLED_KEY, false)
+ if (numExecutor != 0 && streamingDynamicAllocationEnabled) {
+ throw new IllegalArgumentException(
+ "Dynamic Allocation for streaming cannot be enabled while spark.executor.instances is set.")
+ }
+ if (Utils.isDynamicAllocationEnabled(conf) && streamingDynamicAllocationEnabled) {
+ throw new IllegalArgumentException(
+ """
+ |Dynamic Allocation cannot be enabled for both streaming and core at the same time.
+ |Please disable core Dynamic Allocation by setting spark.dynamicAllocation.enabled to
+ |false to use Dynamic Allocation in streaming.
+ """.stripMargin)
+ }
+ val testing = conf.getBoolean("spark.streaming.dynamicAllocation.testing", false)
+ numExecutor == 0 && streamingDynamicAllocationEnabled && (!Utils.isLocalMaster(conf) || testing)
+ }
+
+ def createIfEnabled(
+ client: ExecutorAllocationClient,
+ receiverTracker: ReceiverTracker,
+ conf: SparkConf,
+ batchDurationMs: Long,
+ clock: Clock): Option[ExecutorAllocationManager] = {
+ if (isDynamicAllocationEnabled(conf)) {
+ Some(new ExecutorAllocationManager(client, receiverTracker, conf, batchDurationMs, clock))
+ } else None
+ }
+}
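For illustration, a minimal sketch of enabling this feature through the keys defined above; the app name and numbers below are placeholders, and per isDynamicAllocationEnabled it also requires spark.executor.instances to be unset and core dynamic allocation (spark.dynamicAllocation.enabled) to remain false:

    import org.apache.spark.SparkConf
    import org.apache.spark.streaming.{Seconds, StreamingContext}

    val conf = new SparkConf()
      .setAppName("streaming-dynamic-allocation-demo")                  // placeholder name
      .set("spark.streaming.dynamicAllocation.enabled", "true")
      .set("spark.streaming.dynamicAllocation.minExecutors", "2")       // illustrative bounds
      .set("spark.streaming.dynamicAllocation.maxExecutors", "10")
      .set("spark.streaming.dynamicAllocation.scalingInterval", "60s")  // defaults shown above
      .set("spark.streaming.dynamicAllocation.scalingUpRatio", "0.9")
      .set("spark.streaming.dynamicAllocation.scalingDownRatio", "0.3")
    val ssc = new StreamingContext(conf, Seconds(10))
    // define input streams and output operations here, then:
    ssc.start()
    ssc.awaitTermination()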
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala
index 86f069b0bd..307ff1f7ec 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala
@@ -241,11 +241,6 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging {
/** Generate jobs and perform checkpoint for the given `time`. */
private def generateJobs(time: Time) {
- // Set the SparkEnv in this thread, so that job generation code can access the environment
- // Example: BlockRDDs are created in this thread, and it needs to access BlockManager
- // Update: This is probably redundant after threadlocal stuff in SparkEnv has been removed.
- SparkEnv.set(ssc.env)
-
// Checkpoint all RDDs marked for checkpointing to ensure their lineages are
// truncated periodically. Otherwise, we may run into stack overflows (SPARK-6847).
ssc.sparkContext.setLocalProperty(RDD.CHECKPOINT_ALL_MARKED_ANCESTORS, "true")
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala
index 61f9e0974c..ac18f73ea8 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala
@@ -17,11 +17,14 @@
package org.apache.spark.streaming.scheduler
+import java.util.Properties
import java.util.concurrent.{ConcurrentHashMap, TimeUnit}
import scala.collection.JavaConverters._
import scala.util.Failure
+import org.apache.commons.lang.SerializationUtils
+
import org.apache.spark.internal.Logging
import org.apache.spark.rdd.{PairRDDFunctions, RDD}
import org.apache.spark.streaming._
@@ -57,6 +60,8 @@ class JobScheduler(val ssc: StreamingContext) extends Logging {
// A tracker to track all the input stream information as well as processed record number
var inputInfoTracker: InputInfoTracker = null
+ private var executorAllocationManager: Option[ExecutorAllocationManager] = None
+
private var eventLoop: EventLoop[JobSchedulerEvent] = null
def start(): Unit = synchronized {
@@ -79,8 +84,16 @@ class JobScheduler(val ssc: StreamingContext) extends Logging {
listenerBus.start()
receiverTracker = new ReceiverTracker(ssc)
inputInfoTracker = new InputInfoTracker(ssc)
+ executorAllocationManager = ExecutorAllocationManager.createIfEnabled(
+ ssc.sparkContext,
+ receiverTracker,
+ ssc.conf,
+ ssc.graph.batchDuration.milliseconds,
+ clock)
+ executorAllocationManager.foreach(ssc.addStreamingListener)
receiverTracker.start()
jobGenerator.start()
+ executorAllocationManager.foreach(_.start())
logInfo("Started JobScheduler")
}
@@ -93,6 +106,10 @@ class JobScheduler(val ssc: StreamingContext) extends Logging {
receiverTracker.stop(processAllReceivedData)
}
+ if (executorAllocationManager != null) {
+ executorAllocationManager.foreach(_.stop())
+ }
+
// Second, stop generating jobs. If it has to process all received data,
// then this will wait for all the processing through JobScheduler to be over.
jobGenerator.stop(processAllReceivedData)
@@ -200,7 +217,10 @@ class JobScheduler(val ssc: StreamingContext) extends Logging {
import JobScheduler._
def run() {
+ val oldProps = ssc.sparkContext.getLocalProperties
try {
+ ssc.sparkContext.setLocalProperties(
+ SerializationUtils.clone(ssc.savedProperties.get()).asInstanceOf[Properties])
val formattedTime = UIUtils.formatBatchTime(
job.time.milliseconds, ssc.graph.batchDuration.milliseconds, showYYYYMMSS = false)
val batchUrl = s"/streaming/batch/?id=${job.time.milliseconds}"
@@ -234,8 +254,7 @@ class JobScheduler(val ssc: StreamingContext) extends Logging {
// JobScheduler has been stopped.
}
} finally {
- ssc.sc.setLocalProperty(JobScheduler.BATCH_TIME_PROPERTY_KEY, null)
- ssc.sc.setLocalProperty(JobScheduler.OUTPUT_OP_ID_PROPERTY_KEY, null)
+ ssc.sparkContext.setLocalProperties(oldProps)
}
}
}
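With this change each streaming job runs with a clone of the local properties captured in ssc.savedProperties, and the handler thread restores its previous properties afterwards; the StreamingContextSuite change later in this patch shows that properties set before start() are visible to streaming jobs while later changes on the driver thread are not. A small sketch of the observable behavior, using an assumed property key:

    ssc.sparkContext.setLocalProperty("customPropKey", "value1")
    ssc.start()   // streaming jobs see the properties saved at this point
    ssc.sparkContext.setLocalProperty("customPropKey", "value2")

    // inside a streaming job (e.g. a foreachRDD body):
    //   sc.getLocalProperty("customPropKey") == "value1"
    // on the driver thread afterwards:
    //   sc.getLocalProperty("customPropKey") == "value2"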
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobSet.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobSet.scala
index 66d5ffb797..0baedaf275 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobSet.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobSet.scala
@@ -21,9 +21,10 @@ import scala.collection.mutable.HashSet
import org.apache.spark.streaming.Time
-/** Class representing a set of Jobs
- * belong to the same batch.
- */
+/**
+ * Class representing a set of Jobs
+ * belonging to the same batch.
+ */
private[streaming]
case class JobSet(
time: Time,
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala
index 9c8e68b03d..5d9a8ac0d9 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala
@@ -119,7 +119,7 @@ private[streaming] class ReceivedBlockTracker(
timeToAllocatedBlocks.put(batchTime, allocatedBlocks)
lastAllocatedBatchTime = batchTime
} else {
- logInfo(s"Possibly processed batch $batchTime need to be processed again in WAL recovery")
+ logInfo(s"Possibly processed batch $batchTime needs to be processed again in WAL recovery")
}
} else {
// This situation occurs when:
@@ -129,7 +129,7 @@ private[streaming] class ReceivedBlockTracker(
// 2. Slow checkpointing makes recovered batch time older than WAL recovered
// lastAllocatedBatchTime.
// This situation will only occurs in recovery time.
- logInfo(s"Possibly processed batch $batchTime need to be processed again in WAL recovery")
+ logInfo(s"Possibly processed batch $batchTime needs to be processed again in WAL recovery")
}
}
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala
index b3ae287001..9aa2f0bbb9 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala
@@ -92,6 +92,8 @@ private[streaming] case object AllReceiverIds extends ReceiverTrackerLocalMessag
private[streaming] case class UpdateReceiverRateLimit(streamUID: Int, newRate: Long)
extends ReceiverTrackerLocalMessage
+private[streaming] case object GetAllReceiverInfo extends ReceiverTrackerLocalMessage
+
/**
* This class manages the execution of the receivers of ReceiverInputDStreams. Instance of
* this class must be created after all input streams have been added and StreamingContext.start()
@@ -234,6 +236,26 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false
}
}
+ /**
+ * Get the executors allocated to each receiver.
+ * @return a map from receiver ids to optional executor ids.
+ */
+ def allocatedExecutors(): Map[Int, Option[String]] = synchronized {
+ if (isTrackerStarted) {
+ endpoint.askWithRetry[Map[Int, ReceiverTrackingInfo]](GetAllReceiverInfo).mapValues {
+ _.runningExecutor.map {
+ _.executorId
+ }
+ }
+ } else {
+ Map.empty
+ }
+ }
+
+ def numReceivers(): Int = {
+ receiverInputStreams.size
+ }
+
/** Register a receiver */
private def registerReceiver(
streamId: Int,
@@ -412,11 +434,11 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false
* worker nodes as a parallel collection, and runs them.
*/
private def launchReceivers(): Unit = {
- val receivers = receiverInputStreams.map(nis => {
+ val receivers = receiverInputStreams.map { nis =>
val rcvr = nis.getReceiver()
rcvr.setReceiverId(nis.id)
rcvr
- })
+ }
runDummySparkJob()
@@ -506,9 +528,12 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false
case DeregisterReceiver(streamId, message, error) =>
deregisterReceiver(streamId, message, error)
context.reply(true)
+
// Local messages
case AllReceiverIds =>
context.reply(receiverTrackingInfos.filter(_._2.state != ReceiverState.INACTIVE).keys.toSeq)
+ case GetAllReceiverInfo =>
+ context.reply(receiverTrackingInfos.toMap)
case StopAllReceivers =>
assert(isTrackerStopping || isTrackerStopped)
stopReceivers()
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/rate/RateEstimator.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/rate/RateEstimator.scala
index d7210f64fc..7b2ef6881d 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/rate/RateEstimator.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/rate/RateEstimator.scala
@@ -21,18 +21,20 @@ import org.apache.spark.SparkConf
import org.apache.spark.streaming.Duration
/**
- * A component that estimates the rate at wich an InputDStream should ingest
- * elements, based on updates at every batch completion.
+ * A component that estimates the rate at which an `InputDStream` should ingest
+ * records, based on updates at every batch completion.
+ *
+ * @see [[org.apache.spark.streaming.scheduler.RateController]]
*/
private[streaming] trait RateEstimator extends Serializable {
/**
- * Computes the number of elements the stream attached to this `RateEstimator`
+ * Computes the number of records the stream attached to this `RateEstimator`
* should ingest per second, given an update on the size and completion
* times of the latest batch.
*
- * @param time The timetamp of the current batch interval that just finished
- * @param elements The number of elements that were processed in this batch
+ * @param time The timestamp of the current batch interval that just finished
+ * @param elements The number of records that were processed in this batch
* @param processingDelay The time in ms that took for the job to complete
* @param schedulingDelay The time in ms that the job spent in the scheduling queue
*/
@@ -46,13 +48,13 @@ private[streaming] trait RateEstimator extends Serializable {
object RateEstimator {
/**
- * Return a new RateEstimator based on the value of `spark.streaming.RateEstimator`.
+ * Return a new `RateEstimator` based on the value of
+ * `spark.streaming.backpressure.rateEstimator`.
*
- * The only known estimator right now is `pid`.
+ * The only known and acceptable estimator right now is `pid`.
*
* @return An instance of RateEstimator
- * @throws IllegalArgumentException if there is a configured RateEstimator that doesn't match any
- * known estimators.
+ * @throws IllegalArgumentException if the configured RateEstimator is not `pid`.
*/
def create(conf: SparkConf, batchInterval: Duration): RateEstimator =
conf.get("spark.streaming.backpressure.rateEstimator", "pid") match {
@@ -64,6 +66,6 @@ object RateEstimator {
new PIDRateEstimator(batchInterval.milliseconds, proportional, integral, derived, minRate)
case estimator =>
- throw new IllegalArgumentException(s"Unkown rate estimator: $estimator")
+ throw new IllegalArgumentException(s"Unknown rate estimator: $estimator")
}
}
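A short sketch of selecting the estimator through configuration, as the updated doc describes; spark.streaming.backpressure.enabled is the standard switch that makes the estimator take effect and is assumed here rather than shown in this hunk:

    import org.apache.spark.SparkConf

    val conf = new SparkConf()
      .set("spark.streaming.backpressure.enabled", "true")
      .set("spark.streaming.backpressure.rateEstimator", "pid")  // the only accepted value
    // Any other estimator name makes RateEstimator.create throw
    // IllegalArgumentException("Unknown rate estimator: ...")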
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/AllBatchesTable.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/AllBatchesTable.scala
index d339723427..c024b4ef7e 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/ui/AllBatchesTable.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/AllBatchesTable.scala
@@ -52,7 +52,7 @@ private[ui] abstract class BatchTableBase(tableId: String, batchInterval: Long)
protected def baseRow(batch: BatchUIData): Seq[Node] = {
val batchTime = batch.batchTime.milliseconds
val formattedBatchTime = UIUtils.formatBatchTime(batchTime, batchInterval)
- val eventCount = batch.numRecords
+ val numRecords = batch.numRecords
val schedulingDelay = batch.schedulingDelay
val formattedSchedulingDelay = schedulingDelay.map(SparkUIUtils.formatDuration).getOrElse("-")
val processingTime = batch.processingDelay
@@ -65,7 +65,7 @@ private[ui] abstract class BatchTableBase(tableId: String, batchInterval: Long)
{formattedBatchTime}
</a>
</td>
- <td sorttable_customkey={eventCount.toString}>{eventCount.toString} events</td>
+ <td sorttable_customkey={numRecords.toString}>{numRecords.toString} records</td>
<td sorttable_customkey={schedulingDelay.getOrElse(Long.MaxValue).toString}>
{formattedSchedulingDelay}
</td>
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingJobProgressListener.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingJobProgressListener.scala
index d6fcc582b9..c086df47d9 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingJobProgressListener.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingJobProgressListener.scala
@@ -28,7 +28,7 @@ import org.apache.spark.streaming.{StreamingContext, Time}
import org.apache.spark.streaming.scheduler._
private[streaming] class StreamingJobProgressListener(ssc: StreamingContext)
- extends StreamingListener with SparkListener {
+ extends SparkListener with StreamingListener {
private val waitingBatchUIData = new HashMap[Time, BatchUIData]
private val runningBatchUIData = new HashMap[Time, BatchUIData]
@@ -202,21 +202,21 @@ private[streaming] class StreamingJobProgressListener(ssc: StreamingContext)
def streamIds: Seq[Int] = ssc.graph.getInputStreams().map(_.id)
/**
- * Return all of the event rates for each InputDStream in each batch. The key of the return value
- * is the stream id, and the value is a sequence of batch time with its event rate.
+ * Return all of the record rates for each InputDStream in each batch. The key of the return value
+ * is the stream id, and the value is a sequence of batch time with its record rate.
*/
- def receivedEventRateWithBatchTime: Map[Int, Seq[(Long, Double)]] = synchronized {
+ def receivedRecordRateWithBatchTime: Map[Int, Seq[(Long, Double)]] = synchronized {
val _retainedBatches = retainedBatches
val latestBatches = _retainedBatches.map { batchUIData =>
(batchUIData.batchTime.milliseconds, batchUIData.streamIdToInputInfo.mapValues(_.numRecords))
}
streamIds.map { streamId =>
- val eventRates = latestBatches.map {
+ val recordRates = latestBatches.map {
case (batchTime, streamIdToNumRecords) =>
val numRecords = streamIdToNumRecords.getOrElse(streamId, 0L)
(batchTime, numRecords * 1000.0 / batchDuration)
}
- (streamId, eventRates)
+ (streamId, recordRates)
}.toMap
}
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala
index fa40436221..b97e24f28b 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala
@@ -125,9 +125,9 @@ private[ui] class MillisecondsStatUIData(data: Seq[(Long, Long)]) {
* A helper class for "input rate" to generate data that will be used in the timeline and histogram
* graphs.
*
- * @param data (batchTime, event-rate).
+ * @param data (batch time, record rate).
*/
-private[ui] class EventRateUIData(val data: Seq[(Long, Double)]) {
+private[ui] class RecordRateUIData(val data: Seq[(Long, Double)]) {
val avg: Option[Double] = if (data.isEmpty) None else Some(data.map(_._2).sum / data.size)
@@ -215,7 +215,7 @@ private[ui] class StreamingPage(parent: StreamingTab)
val minBatchTime = if (batchTimes.isEmpty) startTime else batchTimes.min
val maxBatchTime = if (batchTimes.isEmpty) startTime else batchTimes.max
- val eventRateForAllStreams = new EventRateUIData(batches.map { batchInfo =>
+ val recordRateForAllStreams = new RecordRateUIData(batches.map { batchInfo =>
(batchInfo.batchTime.milliseconds, batchInfo.numRecords * 1000.0 / listener.batchDuration)
})
@@ -241,24 +241,24 @@ private[ui] class StreamingPage(parent: StreamingTab)
// Use the max input rate for all InputDStreams' graphs to make the Y axis ranges same.
// If it's not an integral number, just use its ceil integral number.
- val maxEventRate = eventRateForAllStreams.max.map(_.ceil.toLong).getOrElse(0L)
- val minEventRate = 0L
+ val maxRecordRate = recordRateForAllStreams.max.map(_.ceil.toLong).getOrElse(0L)
+ val minRecordRate = 0L
val batchInterval = UIUtils.convertToTimeUnit(listener.batchDuration, normalizedUnit)
val jsCollector = new JsCollector
- val graphUIDataForEventRateOfAllStreams =
+ val graphUIDataForRecordRateOfAllStreams =
new GraphUIData(
- "all-stream-events-timeline",
- "all-stream-events-histogram",
- eventRateForAllStreams.data,
+ "all-stream-records-timeline",
+ "all-stream-records-histogram",
+ recordRateForAllStreams.data,
minBatchTime,
maxBatchTime,
- minEventRate,
- maxEventRate,
- "events/sec")
- graphUIDataForEventRateOfAllStreams.generateDataJs(jsCollector)
+ minRecordRate,
+ maxRecordRate,
+ "records/sec")
+ graphUIDataForRecordRateOfAllStreams.generateDataJs(jsCollector)
val graphUIDataForSchedulingDelay =
new GraphUIData(
@@ -334,16 +334,16 @@ private[ui] class StreamingPage(parent: StreamingTab)
<div>Receivers: {listener.numActiveReceivers} / {numReceivers} active</div>
}
}
- <div>Avg: {eventRateForAllStreams.formattedAvg} events/sec</div>
+ <div>Avg: {recordRateForAllStreams.formattedAvg} records/sec</div>
</div>
</td>
- <td class="timeline">{graphUIDataForEventRateOfAllStreams.generateTimelineHtml(jsCollector)}</td>
- <td class="histogram">{graphUIDataForEventRateOfAllStreams.generateHistogramHtml(jsCollector)}</td>
+ <td class="timeline">{graphUIDataForRecordRateOfAllStreams.generateTimelineHtml(jsCollector)}</td>
+ <td class="histogram">{graphUIDataForRecordRateOfAllStreams.generateHistogramHtml(jsCollector)}</td>
</tr>
{if (hasStream) {
<tr id="inputs-table" style="display: none;" >
<td colspan="3">
- {generateInputDStreamsTable(jsCollector, minBatchTime, maxBatchTime, minEventRate, maxEventRate)}
+ {generateInputDStreamsTable(jsCollector, minBatchTime, maxBatchTime, minRecordRate, maxRecordRate)}
</td>
</tr>
}}
@@ -390,15 +390,16 @@ private[ui] class StreamingPage(parent: StreamingTab)
maxX: Long,
minY: Double,
maxY: Double): Seq[Node] = {
- val maxYCalculated = listener.receivedEventRateWithBatchTime.values
- .flatMap { case streamAndRates => streamAndRates.map { case (_, eventRate) => eventRate } }
+ val maxYCalculated = listener.receivedRecordRateWithBatchTime.values
+ .flatMap { case streamAndRates => streamAndRates.map { case (_, recordRate) => recordRate } }
.reduceOption[Double](math.max)
.map(_.ceil.toLong)
.getOrElse(0L)
- val content = listener.receivedEventRateWithBatchTime.toList.sortBy(_._1).map {
- case (streamId, eventRates) =>
- generateInputDStreamRow(jsCollector, streamId, eventRates, minX, maxX, minY, maxYCalculated)
+ val content = listener.receivedRecordRateWithBatchTime.toList.sortBy(_._1).map {
+ case (streamId, recordRates) =>
+ generateInputDStreamRow(
+ jsCollector, streamId, recordRates, minX, maxX, minY, maxYCalculated)
}.foldLeft[Seq[Node]](Nil)(_ ++ _)
// scalastyle:off
@@ -422,7 +423,7 @@ private[ui] class StreamingPage(parent: StreamingTab)
private def generateInputDStreamRow(
jsCollector: JsCollector,
streamId: Int,
- eventRates: Seq[(Long, Double)],
+ recordRates: Seq[(Long, Double)],
minX: Long,
maxX: Long,
minY: Double,
@@ -447,25 +448,25 @@ private[ui] class StreamingPage(parent: StreamingTab)
val receiverLastErrorTime = receiverInfo.map {
r => if (r.lastErrorTime < 0) "-" else SparkUIUtils.formatDate(r.lastErrorTime)
}.getOrElse(emptyCell)
- val receivedRecords = new EventRateUIData(eventRates)
+ val receivedRecords = new RecordRateUIData(recordRates)
- val graphUIDataForEventRate =
+ val graphUIDataForRecordRate =
new GraphUIData(
- s"stream-$streamId-events-timeline",
- s"stream-$streamId-events-histogram",
+ s"stream-$streamId-records-timeline",
+ s"stream-$streamId-records-histogram",
receivedRecords.data,
minX,
maxX,
minY,
maxY,
- "events/sec")
- graphUIDataForEventRate.generateDataJs(jsCollector)
+ "records/sec")
+ graphUIDataForRecordRate.generateDataJs(jsCollector)
<tr>
<td rowspan="2" style="vertical-align: middle; width: 151px;">
<div style="width: 151px;">
<div style="word-wrap: break-word;"><strong>{receiverName}</strong></div>
- <div>Avg: {receivedRecords.formattedAvg} events/sec</div>
+ <div>Avg: {receivedRecords.formattedAvg} records/sec</div>
</div>
</td>
<td>{receiverActive}</td>
@@ -475,9 +476,9 @@ private[ui] class StreamingPage(parent: StreamingTab)
</tr>
<tr>
<td colspan="3" class="timeline">
- {graphUIDataForEventRate.generateTimelineHtml(jsCollector)}
+ {graphUIDataForRecordRate.generateTimelineHtml(jsCollector)}
</td>
- <td class="histogram">{graphUIDataForEventRate.generateHistogramHtml(jsCollector)}</td>
+ <td class="histogram">{graphUIDataForRecordRate.generateHistogramHtml(jsCollector)}</td>
</tr>
}
diff --git a/streaming/src/test/resources/log4j.properties b/streaming/src/test/resources/log4j.properties
index 75e3b53a09..fd51f8faf5 100644
--- a/streaming/src/test/resources/log4j.properties
+++ b/streaming/src/test/resources/log4j.properties
@@ -24,5 +24,5 @@ log4j.appender.file.layout=org.apache.log4j.PatternLayout
log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n
# Ignore messages below warning level from Jetty, because it's a bit verbose
-log4j.logger.org.spark-project.jetty=WARN
+log4j.logger.org.spark_project.jetty=WARN
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala
index bd60059b18..cfcbdc7c38 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala
@@ -538,10 +538,9 @@ class BasicOperationsSuite extends TestSuiteBase {
val stateObj = state.getOrElse(new StateObject)
values.sum match {
case 0 => stateObj.expireCounter += 1 // no new values
- case n => { // has new values, increment and reset expireCounter
+ case n => // has new values, increment and reset expireCounter
stateObj.counter += n
stateObj.expireCounter = 0
- }
}
stateObj.expireCounter match {
case 2 => None // seen twice with no new values, give it the boot
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala
index 9a3248b3e8..bdbac64b9b 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala
@@ -228,6 +228,11 @@ class CheckpointSuite extends TestSuiteBase with DStreamCheckpointTester
}
}
+ test("non-existent checkpoint dir") {
+ // SPARK-13211
+ intercept[IllegalArgumentException](new StreamingContext("nosuchdirectory"))
+ }
+
test("basic rdd checkpoints + dstream graph checkpoint recovery") {
assert(batchDuration === Milliseconds(500), "batchDuration for this test must be 1 second")
@@ -262,10 +267,9 @@ class CheckpointSuite extends TestSuiteBase with DStreamCheckpointTester
assert(!stateStream.checkpointData.currentCheckpointFiles.isEmpty,
"No checkpointed RDDs in state stream before first failure")
stateStream.checkpointData.currentCheckpointFiles.foreach {
- case (time, file) => {
+ case (time, file) =>
assert(fs.exists(new Path(file)), "Checkpoint file '" + file +"' for time " + time +
" for state stream before first failure does not exist")
- }
}
// Run till a further time such that previous checkpoint files in the stream would be deleted
@@ -292,10 +296,9 @@ class CheckpointSuite extends TestSuiteBase with DStreamCheckpointTester
assert(!stateStream.checkpointData.currentCheckpointFiles.isEmpty,
"No checkpointed RDDs in state stream before second failure")
stateStream.checkpointData.currentCheckpointFiles.foreach {
- case (time, file) => {
+ case (time, file) =>
assert(fs.exists(new Path(file)), "Checkpoint file '" + file +"' for time " + time +
" for state stream before seconds failure does not exist")
- }
}
ssc.stop()
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/MasterFailureTest.scala b/streaming/src/test/scala/org/apache/spark/streaming/MasterFailureTest.scala
index 29bee4adf2..60c8e70235 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/MasterFailureTest.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/MasterFailureTest.scala
@@ -382,11 +382,10 @@ class FileGeneratingThread(input: Seq[String], testDir: Path, interval: Long)
fs.rename(tempHadoopFile, hadoopFile)
done = true
} catch {
- case ioe: IOException => {
+ case ioe: IOException =>
fs = testDir.getFileSystem(new Configuration())
logWarning("Attempt " + tries + " at generating file " + hadoopFile + " failed.",
ioe)
- }
}
}
if (!done) {
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala
index 4e77cd6347..39d0de5179 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala
@@ -157,7 +157,8 @@ class ReceivedBlockHandlerSuite
val reader = new FileBasedWriteAheadLogRandomReader(fileSegment.path, hadoopConf)
val bytes = reader.read(fileSegment)
reader.close()
- serializerManager.dataDeserialize(generateBlockId(), new ChunkedByteBuffer(bytes)).toList
+ serializerManager.dataDeserializeStream(
+ generateBlockId(), new ChunkedByteBuffer(bytes).toInputStream()).toList
}
loggedData shouldEqual data
}
@@ -265,7 +266,7 @@ class ReceivedBlockHandlerSuite
conf: SparkConf,
name: String = SparkContext.DRIVER_IDENTIFIER): BlockManager = {
val memManager = new StaticMemoryManager(conf, Long.MaxValue, maxMem, numCores = 1)
- val transfer = new NettyBlockTransferService(conf, securityMgr, numCores = 1)
+ val transfer = new NettyBlockTransferService(conf, securityMgr, "localhost", numCores = 1)
val blockManager = new BlockManager(name, rpcEnv, blockManagerMaster, serializerManager, conf,
memManager, mapOutputTracker, shuffleManager, transfer, securityMgr, 0)
memManager.setMemoryStore(blockManager.memoryStore)
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala
index a80154e2fc..806e181f61 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala
@@ -182,7 +182,7 @@ class StreamingContextSuite extends SparkFunSuite with BeforeAndAfter with Timeo
assert(ssc.scheduler.isStarted === false)
}
- test("start should set job group and description of streaming jobs correctly") {
+ test("start should set local properties of streaming jobs correctly") {
ssc = new StreamingContext(conf, batchDuration)
ssc.sc.setJobGroup("non-streaming", "non-streaming", true)
val sc = ssc.sc
@@ -190,16 +190,22 @@ class StreamingContextSuite extends SparkFunSuite with BeforeAndAfter with Timeo
@volatile var jobGroupFound: String = ""
@volatile var jobDescFound: String = ""
@volatile var jobInterruptFound: String = ""
+ @volatile var customPropFound: String = ""
@volatile var allFound: Boolean = false
addInputStream(ssc).foreachRDD { rdd =>
jobGroupFound = sc.getLocalProperty(SparkContext.SPARK_JOB_GROUP_ID)
jobDescFound = sc.getLocalProperty(SparkContext.SPARK_JOB_DESCRIPTION)
jobInterruptFound = sc.getLocalProperty(SparkContext.SPARK_JOB_INTERRUPT_ON_CANCEL)
+ customPropFound = sc.getLocalProperty("customPropKey")
allFound = true
}
+ ssc.sc.setLocalProperty("customPropKey", "value1")
ssc.start()
+ // Local props set after start should be ignored
+ ssc.sc.setLocalProperty("customPropKey", "value2")
+
eventually(timeout(10 seconds), interval(10 milliseconds)) {
assert(allFound === true)
}
@@ -208,11 +214,13 @@ class StreamingContextSuite extends SparkFunSuite with BeforeAndAfter with Timeo
assert(jobGroupFound === null)
assert(jobDescFound.contains("Streaming job from"))
assert(jobInterruptFound === "false")
+ assert(customPropFound === "value1")
// Verify current thread's thread-local properties have not changed
assert(sc.getLocalProperty(SparkContext.SPARK_JOB_GROUP_ID) === "non-streaming")
assert(sc.getLocalProperty(SparkContext.SPARK_JOB_DESCRIPTION) === "non-streaming")
assert(sc.getLocalProperty(SparkContext.SPARK_JOB_INTERRUPT_ON_CANCEL) === "true")
+ assert(sc.getLocalProperty("customPropKey") === "value2")
}
test("start multiple times") {
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/UISeleniumSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/UISeleniumSuite.scala
index 3f12de38ef..454c3dffa3 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/UISeleniumSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/UISeleniumSuite.scala
@@ -169,9 +169,9 @@ class UISeleniumSuite
List("4/4", "4/4", "4/4", "0/4 (1 failed)"))
// Check stacktrace
- val errorCells = findAll(cssSelector(""".stacktrace-details""")).map(_.text).toSeq
+ val errorCells = findAll(cssSelector(""".stacktrace-details""")).map(_.underlying).toSeq
errorCells should have size 1
- errorCells(0) should include("java.lang.RuntimeException: Oops")
+ // Can't get the inner (invisible) text without running JS
// Check the job link in the batch page is right
go to (jobLinks(0))
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ExecutorAllocationManagerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ExecutorAllocationManagerSuite.scala
new file mode 100644
index 0000000000..7630f4a75e
--- /dev/null
+++ b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ExecutorAllocationManagerSuite.scala
@@ -0,0 +1,395 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.streaming.scheduler
+
+import org.mockito.Matchers.{eq => meq}
+import org.mockito.Mockito._
+import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll, PrivateMethodTester}
+import org.scalatest.concurrent.Eventually.{eventually, timeout}
+import org.scalatest.mock.MockitoSugar
+import org.scalatest.time.SpanSugar._
+
+import org.apache.spark.{ExecutorAllocationClient, SparkConf, SparkFunSuite}
+import org.apache.spark.streaming.{DummyInputDStream, Seconds, StreamingContext}
+import org.apache.spark.util.{ManualClock, Utils}
+
+
+class ExecutorAllocationManagerSuite extends SparkFunSuite
+ with BeforeAndAfter with BeforeAndAfterAll with MockitoSugar with PrivateMethodTester {
+
+ import ExecutorAllocationManager._
+
+ private val batchDurationMillis = 1000L
+ private var allocationClient: ExecutorAllocationClient = null
+ private var clock: ManualClock = null
+
+ before {
+ allocationClient = mock[ExecutorAllocationClient]
+ clock = new ManualClock()
+ }
+
+ test("basic functionality") {
+ // Test that adding batch processing time info to allocation manager
+ // causes executors to be requested and killed accordingly
+
+ // There is 1 receiver, and exec 1 has been allocated to it
+ withAllocationManager(numReceivers = 1) { case (receiverTracker, allocationManager) =>
+ when(receiverTracker.allocatedExecutors).thenReturn(Map(1 -> Some("1")))
+
+ /** Add data point for batch processing time and verify executor allocation */
+ def addBatchProcTimeAndVerifyAllocation(batchProcTimeMs: Double)(body: => Unit): Unit = {
+ // 2 active executors
+ reset(allocationClient)
+ when(allocationClient.getExecutorIds()).thenReturn(Seq("1", "2"))
+ addBatchProcTime(allocationManager, batchProcTimeMs.toLong)
+ clock.advance(SCALING_INTERVAL_DEFAULT_SECS * 1000 + 1)
+ eventually(timeout(10 seconds)) {
+ body
+ }
+ }
+
+ /** Verify that the expected number of total executors was requested */
+ def verifyTotalRequestedExecs(expectedRequestedTotalExecs: Option[Int]): Unit = {
+ if (expectedRequestedTotalExecs.nonEmpty) {
+ require(expectedRequestedTotalExecs.get > 0)
+ verify(allocationClient, times(1)).requestTotalExecutors(
+ meq(expectedRequestedTotalExecs.get), meq(0), meq(Map.empty))
+ } else {
+ verify(allocationClient, never).requestTotalExecutors(0, 0, Map.empty)
+ }
+ }
+
+ /** Verify that a particular executor was killed */
+ def verifyKilledExec(expectedKilledExec: Option[String]): Unit = {
+ if (expectedKilledExec.nonEmpty) {
+ verify(allocationClient, times(1)).killExecutor(meq(expectedKilledExec.get))
+ } else {
+ verify(allocationClient, never).killExecutor(null)
+ }
+ }
+
+ // Batch proc time = batch interval, should increase allocation by 1
+ addBatchProcTimeAndVerifyAllocation(batchDurationMillis) {
+ verifyTotalRequestedExecs(Some(3)) // one already allocated, increase allocation by 1
+ verifyKilledExec(None)
+ }
+
+ // Batch proc time = batch interval * 2, should increase allocation by 2
+ addBatchProcTimeAndVerifyAllocation(batchDurationMillis * 2) {
+ verifyTotalRequestedExecs(Some(4))
+ verifyKilledExec(None)
+ }
+
+ // Batch proc time slightly more than the scale up ratio, should increase allocation by 1
+ addBatchProcTimeAndVerifyAllocation(batchDurationMillis * SCALING_UP_RATIO_DEFAULT + 1) {
+ verifyTotalRequestedExecs(Some(3))
+ verifyKilledExec(None)
+ }
+
+ // Batch proc time slightly less than the scale up ratio, should not change allocation
+ addBatchProcTimeAndVerifyAllocation(batchDurationMillis * SCALING_UP_RATIO_DEFAULT - 1) {
+ verifyTotalRequestedExecs(None)
+ verifyKilledExec(None)
+ }
+
+ // Batch proc time slightly more than the scale down ratio, should not change allocation
+ addBatchProcTimeAndVerifyAllocation(batchDurationMillis * SCALING_DOWN_RATIO_DEFAULT + 1) {
+ verifyTotalRequestedExecs(None)
+ verifyKilledExec(None)
+ }
+
+ // Batch proc time slightly less than the scale down ratio, should kill an executor
+ addBatchProcTimeAndVerifyAllocation(batchDurationMillis * SCALING_DOWN_RATIO_DEFAULT - 1) {
+ verifyTotalRequestedExecs(None)
+ verifyKilledExec(Some("2"))
+ }
+ }
+ }
+
+ test("requestExecutors policy") {
+
+ /** Verify that the expected number of total executors was requested */
+ def verifyRequestedExecs(
+ numExecs: Int,
+ numNewExecs: Int,
+ expectedRequestedTotalExecs: Int)(
+ implicit allocationManager: ExecutorAllocationManager): Unit = {
+ reset(allocationClient)
+ when(allocationClient.getExecutorIds()).thenReturn((1 to numExecs).map(_.toString))
+ requestExecutors(allocationManager, numNewExecs)
+ verify(allocationClient, times(1)).requestTotalExecutors(
+ meq(expectedRequestedTotalExecs), meq(0), meq(Map.empty))
+ }
+
+ withAllocationManager(numReceivers = 1) { case (_, allocationManager) =>
+ implicit val am = allocationManager
+ intercept[IllegalArgumentException] {
+ verifyRequestedExecs(numExecs = 0, numNewExecs = 0, 0)
+ }
+ verifyRequestedExecs(numExecs = 0, numNewExecs = 1, expectedRequestedTotalExecs = 1)
+ verifyRequestedExecs(numExecs = 1, numNewExecs = 1, expectedRequestedTotalExecs = 2)
+ verifyRequestedExecs(numExecs = 2, numNewExecs = 2, expectedRequestedTotalExecs = 4)
+ }
+
+ withAllocationManager(numReceivers = 2) { case(_, allocationManager) =>
+ implicit val am = allocationManager
+
+ verifyRequestedExecs(numExecs = 0, numNewExecs = 1, expectedRequestedTotalExecs = 2)
+ verifyRequestedExecs(numExecs = 1, numNewExecs = 1, expectedRequestedTotalExecs = 2)
+ verifyRequestedExecs(numExecs = 2, numNewExecs = 2, expectedRequestedTotalExecs = 4)
+ }
+
+ withAllocationManager(
+ // Test min 2 executors
+ new SparkConf().set("spark.streaming.dynamicAllocation.minExecutors", "2")) {
+ case (_, allocationManager) =>
+ implicit val am = allocationManager
+
+ verifyRequestedExecs(numExecs = 0, numNewExecs = 1, expectedRequestedTotalExecs = 2)
+ verifyRequestedExecs(numExecs = 0, numNewExecs = 3, expectedRequestedTotalExecs = 3)
+ verifyRequestedExecs(numExecs = 1, numNewExecs = 1, expectedRequestedTotalExecs = 2)
+ verifyRequestedExecs(numExecs = 1, numNewExecs = 2, expectedRequestedTotalExecs = 3)
+ verifyRequestedExecs(numExecs = 2, numNewExecs = 1, expectedRequestedTotalExecs = 3)
+ verifyRequestedExecs(numExecs = 2, numNewExecs = 2, expectedRequestedTotalExecs = 4)
+ }
+
+ withAllocationManager(
+ // Test with max 2 executors
+ new SparkConf().set("spark.streaming.dynamicAllocation.maxExecutors", "2")) {
+ case (_, allocationManager) =>
+ implicit val am = allocationManager
+
+ verifyRequestedExecs(numExecs = 0, numNewExecs = 1, expectedRequestedTotalExecs = 1)
+ verifyRequestedExecs(numExecs = 0, numNewExecs = 3, expectedRequestedTotalExecs = 2)
+ verifyRequestedExecs(numExecs = 1, numNewExecs = 2, expectedRequestedTotalExecs = 2)
+ verifyRequestedExecs(numExecs = 2, numNewExecs = 1, expectedRequestedTotalExecs = 2)
+ verifyRequestedExecs(numExecs = 2, numNewExecs = 2, expectedRequestedTotalExecs = 2)
+ }
+ }
+
+ test("killExecutor policy") {
+
+ /**
+ * Verify that a particular executor was killed, given active executors and executors
+ * allocated to receivers.
+ */
+ def verifyKilledExec(
+ execIds: Seq[String],
+ receiverExecIds: Map[Int, Option[String]],
+ expectedKilledExec: Option[String])(
+ implicit x: (ReceiverTracker, ExecutorAllocationManager)): Unit = {
+ val (receiverTracker, allocationManager) = x
+
+ reset(allocationClient)
+ when(allocationClient.getExecutorIds()).thenReturn(execIds)
+ when(receiverTracker.allocatedExecutors).thenReturn(receiverExecIds)
+ killExecutor(allocationManager)
+ if (expectedKilledExec.nonEmpty) {
+ verify(allocationClient, times(1)).killExecutor(meq(expectedKilledExec.get))
+ } else {
+ verify(allocationClient, never).killExecutor(null)
+ }
+ }
+
+ withAllocationManager() { case (receiverTracker, allocationManager) =>
+ implicit val rcvrTrackerAndExecAllocMgr = (receiverTracker, allocationManager)
+
+ verifyKilledExec(Nil, Map.empty, None)
+ verifyKilledExec(Seq("1", "2"), Map.empty, None)
+ verifyKilledExec(Seq("1"), Map(1 -> Some("1")), None)
+ verifyKilledExec(Seq("1", "2"), Map(1 -> Some("1")), Some("2"))
+ verifyKilledExec(Seq("1", "2"), Map(1 -> Some("1"), 2 -> Some("2")), None)
+ }
+
+ withAllocationManager(
+ new SparkConf().set("spark.streaming.dynamicAllocation.minExecutors", "2")) {
+ case (receiverTracker, allocationManager) =>
+ implicit val rcvrTrackerAndExecAllocMgr = (receiverTracker, allocationManager)
+
+ verifyKilledExec(Seq("1", "2"), Map.empty, None)
+ verifyKilledExec(Seq("1", "2", "3"), Map(1 -> Some("1"), 2 -> Some("2")), Some("3"))
+ }
+ }
+
+ test("parameter validation") {
+
+ def validateParams(
+ numReceivers: Int = 1,
+ scalingIntervalSecs: Option[Int] = None,
+ scalingUpRatio: Option[Double] = None,
+ scalingDownRatio: Option[Double] = None,
+ minExecs: Option[Int] = None,
+ maxExecs: Option[Int] = None): Unit = {
+ require(numReceivers > 0)
+ val receiverTracker = mock[ReceiverTracker]
+ when(receiverTracker.numReceivers()).thenReturn(numReceivers)
+ val conf = new SparkConf()
+ if (scalingIntervalSecs.nonEmpty) {
+ conf.set(
+ "spark.streaming.dynamicAllocation.scalingInterval",
+ s"${scalingIntervalSecs.get}s")
+ }
+ if (scalingUpRatio.nonEmpty) {
+ conf.set("spark.streaming.dynamicAllocation.scalingUpRatio", scalingUpRatio.get.toString)
+ }
+ if (scalingDownRatio.nonEmpty) {
+ conf.set(
+ "spark.streaming.dynamicAllocation.scalingDownRatio",
+ scalingDownRatio.get.toString)
+ }
+ if (minExecs.nonEmpty) {
+ conf.set("spark.streaming.dynamicAllocation.minExecutors", minExecs.get.toString)
+ }
+ if (maxExecs.nonEmpty) {
+ conf.set("spark.streaming.dynamicAllocation.maxExecutors", maxExecs.get.toString)
+ }
+ new ExecutorAllocationManager(
+ allocationClient, receiverTracker, conf, batchDurationMillis, clock)
+ }
+
+ validateParams(numReceivers = 1)
+ validateParams(numReceivers = 2, minExecs = Some(1))
+ validateParams(numReceivers = 2, minExecs = Some(3))
+ validateParams(numReceivers = 2, maxExecs = Some(3))
+ validateParams(numReceivers = 2, maxExecs = Some(1))
+ validateParams(minExecs = Some(3), maxExecs = Some(3))
+ validateParams(scalingIntervalSecs = Some(1))
+ validateParams(scalingUpRatio = Some(1.1))
+ validateParams(scalingDownRatio = Some(0.1))
+ validateParams(scalingUpRatio = Some(1.1), scalingDownRatio = Some(0.1))
+
+ intercept[IllegalArgumentException] {
+ validateParams(minExecs = Some(0))
+ }
+ intercept[IllegalArgumentException] {
+ validateParams(minExecs = Some(-1))
+ }
+ intercept[IllegalArgumentException] {
+ validateParams(maxExecs = Some(0))
+ }
+ intercept[IllegalArgumentException] {
+ validateParams(maxExecs = Some(-1))
+ }
+ intercept[IllegalArgumentException] {
+ validateParams(minExecs = Some(4), maxExecs = Some(3))
+ }
+ intercept[IllegalArgumentException] {
+ validateParams(scalingIntervalSecs = Some(-1))
+ }
+ intercept[IllegalArgumentException] {
+ validateParams(scalingIntervalSecs = Some(0))
+ }
+ intercept[IllegalArgumentException] {
+ validateParams(scalingUpRatio = Some(-0.1))
+ }
+ intercept[IllegalArgumentException] {
+ validateParams(scalingUpRatio = Some(0))
+ }
+ intercept[IllegalArgumentException] {
+ validateParams(scalingDownRatio = Some(-0.1))
+ }
+ intercept[IllegalArgumentException] {
+ validateParams(scalingDownRatio = Some(0))
+ }
+ intercept[IllegalArgumentException] {
+ validateParams(scalingUpRatio = Some(0.5), scalingDownRatio = Some(0.5))
+ }
+ intercept[IllegalArgumentException] {
+ validateParams(scalingUpRatio = Some(0.3), scalingDownRatio = Some(0.5))
+ }
+ }
+
+ test("enabling and disabling") {
+ withStreamingContext(new SparkConf()) { ssc =>
+ ssc.start()
+ assert(getExecutorAllocationManager(ssc).isEmpty)
+ }
+
+ withStreamingContext(
+ new SparkConf().set("spark.streaming.dynamicAllocation.enabled", "true")) { ssc =>
+ ssc.start()
+ assert(getExecutorAllocationManager(ssc).nonEmpty)
+ }
+
+ val confWithBothDynamicAllocationEnabled = new SparkConf()
+ .set("spark.streaming.dynamicAllocation.enabled", "true")
+ .set("spark.dynamicAllocation.enabled", "true")
+ .set("spark.dynamicAllocation.testing", "true")
+ require(Utils.isDynamicAllocationEnabled(confWithBothDynamicAllocationEnabled) === true)
+ withStreamingContext(confWithBothDynamicAllocationEnabled) { ssc =>
+ intercept[IllegalArgumentException] {
+ ssc.start()
+ }
+ }
+ }
+
+ private def withAllocationManager(
+ conf: SparkConf = new SparkConf,
+ numReceivers: Int = 1
+ )(body: (ReceiverTracker, ExecutorAllocationManager) => Unit): Unit = {
+
+ val receiverTracker = mock[ReceiverTracker]
+ when(receiverTracker.numReceivers()).thenReturn(numReceivers)
+
+ val manager = new ExecutorAllocationManager(
+ allocationClient, receiverTracker, conf, batchDurationMillis, clock)
+ try {
+ manager.start()
+ body(receiverTracker, manager)
+ } finally {
+ manager.stop()
+ }
+ }
+
+ private val _addBatchProcTime = PrivateMethod[Unit]('addBatchProcTime)
+ private val _requestExecutors = PrivateMethod[Unit]('requestExecutors)
+ private val _killExecutor = PrivateMethod[Unit]('killExecutor)
+ private val _executorAllocationManager =
+ PrivateMethod[Option[ExecutorAllocationManager]]('executorAllocationManager)
+
+ private def addBatchProcTime(manager: ExecutorAllocationManager, timeMs: Long): Unit = {
+ manager invokePrivate _addBatchProcTime(timeMs)
+ }
+
+ private def requestExecutors(manager: ExecutorAllocationManager, newExecs: Int): Unit = {
+ manager invokePrivate _requestExecutors(newExecs)
+ }
+
+ private def killExecutor(manager: ExecutorAllocationManager): Unit = {
+ manager invokePrivate _killExecutor()
+ }
+
+ private def getExecutorAllocationManager(
+ ssc: StreamingContext): Option[ExecutorAllocationManager] = {
+ ssc.scheduler invokePrivate _executorAllocationManager()
+ }
+
+ private def withStreamingContext(conf: SparkConf)(body: StreamingContext => Unit): Unit = {
+ conf.setMaster("local").setAppName(this.getClass.getSimpleName).set(
+ "spark.streaming.dynamicAllocation.testing", "true") // to test dynamic allocation
+
+ var ssc: StreamingContext = null
+ try {
+ ssc = new StreamingContext(conf, Seconds(1))
+ new DummyInputDStream(ssc).foreachRDD(_ => { })
+ body(ssc)
+ } finally {
+ if (ssc != null) ssc.stop()
+ }
+ }
+}
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala
index 7654bb2d03..df122ac090 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala
@@ -26,7 +26,7 @@ import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskStart, TaskLo
import org.apache.spark.scheduler.TaskLocality.TaskLocality
import org.apache.spark.storage.{StorageLevel, StreamBlockId}
import org.apache.spark.streaming._
-import org.apache.spark.streaming.dstream.ReceiverInputDStream
+import org.apache.spark.streaming.dstream.{ConstantInputDStream, ReceiverInputDStream}
import org.apache.spark.streaming.receiver._
/** Testsuite for receiver scheduling */
@@ -102,6 +102,27 @@ class ReceiverTrackerSuite extends TestSuiteBase {
}
}
}
+
+ test("get allocated executors") {
+ // Test get allocated executors when 1 receiver is registered
+ withStreamingContext(new StreamingContext(conf, Milliseconds(100))) { ssc =>
+ val input = ssc.receiverStream(new TestReceiver)
+ val output = new TestOutputStream(input)
+ output.register()
+ ssc.start()
+ assert(ssc.scheduler.receiverTracker.allocatedExecutors().size === 1)
+ }
+
+ // Test get allocated executors when there's no receiver registered
+ withStreamingContext(new StreamingContext(conf, Milliseconds(100))) { ssc =>
+ val rdd = ssc.sc.parallelize(1 to 10)
+ val input = new ConstantInputDStream(ssc, rdd)
+ val output = new TestOutputStream(input)
+ output.register()
+ ssc.start()
+ assert(ssc.scheduler.receiverTracker.allocatedExecutors() === Map.empty)
+ }
+ }
}
/** An input DStream with for testing rate controlling */
diff --git a/tools/src/main/scala/org/apache/spark/tools/GenerateMIMAIgnore.scala b/tools/src/main/scala/org/apache/spark/tools/GenerateMIMAIgnore.scala
index 0df3c501de..c9058ff409 100644
--- a/tools/src/main/scala/org/apache/spark/tools/GenerateMIMAIgnore.scala
+++ b/tools/src/main/scala/org/apache/spark/tools/GenerateMIMAIgnore.scala
@@ -91,10 +91,11 @@ object GenerateMIMAIgnore {
(ignoredClasses.flatMap(c => Seq(c, c.replace("$", "#"))).toSet, ignoredMembers.toSet)
}
- /** Scala reflection does not let us see inner function even if they are upgraded
- * to public for some reason. So had to resort to java reflection to get all inner
- * functions with $$ in there name.
- */
+ /**
+ * Scala reflection does not let us see inner functions even if they are upgraded
+ * to public for some reason, so we had to resort to Java reflection to get all inner
+ * functions with $$ in their names.
+ */
def getInnerFunctions(classSymbol: unv.ClassSymbol): Seq[String] = {
try {
Class.forName(classSymbol.fullName, false, classLoader).getMethods.map(_.getName)
diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
index e941089d1b..d447a59937 100644
--- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
+++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
@@ -374,7 +374,7 @@ private[spark] class ApplicationMaster(
failureCount = 0
} catch {
case i: InterruptedException =>
- case e: Throwable => {
+ case e: Throwable =>
failureCount += 1
// this exception was introduced in hadoop 2.4 and this code would not compile
// with earlier versions if we refer it directly.
@@ -390,7 +390,6 @@ private[spark] class ApplicationMaster(
} else {
logWarning(s"Reporter thread fails $failureCount time(s) in a row.", e)
}
- }
}
try {
val numPendingAllocate = allocator.getPendingAllocate.size
@@ -662,7 +661,7 @@ object ApplicationMaster extends Logging {
SignalLogger.register(log)
val amArgs = new ApplicationMasterArguments(args)
SparkHadoopUtil.get.runAsSparkUser { () =>
- master = new ApplicationMaster(amArgs, new YarnRMClient(amArgs))
+ master = new ApplicationMaster(amArgs, new YarnRMClient)
System.exit(master.run())
}
}
diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala
index 6987e5a55f..5cdec87667 100644
--- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala
+++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala
@@ -27,8 +27,6 @@ class ApplicationMasterArguments(val args: Array[String]) {
var primaryPyFile: String = null
var primaryRFile: String = null
var userArgs: Seq[String] = Nil
- var executorMemory = 1024
- var executorCores = 1
var propertiesFile: String = null
parseArgs(args.toList)
@@ -58,18 +56,10 @@ class ApplicationMasterArguments(val args: Array[String]) {
primaryRFile = value
args = tail
- case ("--args" | "--arg") :: value :: tail =>
+ case ("--arg") :: value :: tail =>
userArgsBuffer += value
args = tail
- case ("--worker-memory" | "--executor-memory") :: MemoryParam(value) :: tail =>
- executorMemory = value
- args = tail
-
- case ("--worker-cores" | "--executor-cores") :: IntParam(value) :: tail =>
- executorCores = value
- args = tail
-
case ("--properties-file") :: value :: tail =>
propertiesFile = value
args = tail
@@ -101,12 +91,8 @@ class ApplicationMasterArguments(val args: Array[String]) {
| --class CLASS_NAME Name of your application's main class
| --primary-py-file A main Python file
| --primary-r-file A main R file
- | --py-files PY_FILES Comma-separated list of .zip, .egg, or .py files to
- | place on the PYTHONPATH for Python apps.
- | --args ARGS Arguments to be passed to your application's main class.
+ | --arg ARG Argument to be passed to your application's main class.
| Multiple invocations are possible, each will be passed in order.
- | --executor-cores NUM Number of cores for the executors (Default: 1)
- | --executor-memory MEM Memory per executor (e.g. 1000M, 2G) (Default: 1G)
| --properties-file FILE Path to a custom Spark properties file.
""".stripMargin)
// scalastyle:on println
diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
index 6bbc8c2dfa..04e91f8553 100644
--- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
+++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
@@ -64,21 +64,44 @@ private[spark] class Client(
extends Logging {
import Client._
+ import YarnSparkHadoopUtil._
def this(clientArgs: ClientArguments, spConf: SparkConf) =
this(clientArgs, SparkHadoopUtil.get.newConfiguration(spConf), spConf)
private val yarnClient = YarnClient.createYarnClient
private val yarnConf = new YarnConfiguration(hadoopConf)
- private var credentials: Credentials = null
- private val amMemoryOverhead = args.amMemoryOverhead // MB
- private val executorMemoryOverhead = args.executorMemoryOverhead // MB
+
+ private val isClusterMode = sparkConf.get("spark.submit.deployMode", "client") == "cluster"
+
+ // AM related configurations
+ private val amMemory = if (isClusterMode) {
+ sparkConf.get(DRIVER_MEMORY).toInt
+ } else {
+ sparkConf.get(AM_MEMORY).toInt
+ }
+ private val amMemoryOverhead = {
+ val amMemoryOverheadEntry = if (isClusterMode) DRIVER_MEMORY_OVERHEAD else AM_MEMORY_OVERHEAD
+ sparkConf.get(amMemoryOverheadEntry).getOrElse(
+ math.max((MEMORY_OVERHEAD_FACTOR * amMemory).toLong, MEMORY_OVERHEAD_MIN)).toInt
+ }
+ private val amCores = if (isClusterMode) {
+ sparkConf.get(DRIVER_CORES)
+ } else {
+ sparkConf.get(AM_CORES)
+ }
+
+ // Executor related configurations
+ private val executorMemory = sparkConf.get(EXECUTOR_MEMORY)
+ private val executorMemoryOverhead = sparkConf.get(EXECUTOR_MEMORY_OVERHEAD).getOrElse(
+ math.max((MEMORY_OVERHEAD_FACTOR * executorMemory).toLong, MEMORY_OVERHEAD_MIN)).toInt
+
private val distCacheMgr = new ClientDistributedCacheManager()
- private val isClusterMode = args.isClusterMode
private var loginFromKeytab = false
private var principal: String = null
private var keytab: String = null
+ private var credentials: Credentials = null
private val launcherBackend = new LauncherBackend() {
override def onStopRequest(): Unit = {
@@ -159,8 +182,8 @@ private[spark] class Client(
val appStagingDir = getAppStagingDir(appId)
try {
val preserveFiles = sparkConf.get(PRESERVE_STAGING_FILES)
- val stagingDirPath = new Path(appStagingDir)
val fs = FileSystem.get(hadoopConf)
+ val stagingDirPath = getAppStagingDirPath(sparkConf, fs, appStagingDir)
if (!preserveFiles && fs.exists(stagingDirPath)) {
logInfo("Deleting staging directory " + stagingDirPath)
fs.delete(stagingDirPath, true)
@@ -179,8 +202,8 @@ private[spark] class Client(
newApp: YarnClientApplication,
containerContext: ContainerLaunchContext): ApplicationSubmissionContext = {
val appContext = newApp.getApplicationSubmissionContext
- appContext.setApplicationName(args.appName)
- appContext.setQueue(args.amQueue)
+ appContext.setApplicationName(sparkConf.get("spark.app.name", "Spark"))
+ appContext.setQueue(sparkConf.get(QUEUE_NAME))
appContext.setAMContainerSpec(containerContext)
appContext.setApplicationType("SPARK")
@@ -217,8 +240,8 @@ private[spark] class Client(
}
val capability = Records.newRecord(classOf[Resource])
- capability.setMemory(args.amMemory + amMemoryOverhead)
- capability.setVirtualCores(args.amCores)
+ capability.setMemory(amMemory + amMemoryOverhead)
+ capability.setVirtualCores(amCores)
sparkConf.get(AM_NODE_LABEL_EXPRESSION) match {
case Some(expr) =>
@@ -272,16 +295,16 @@ private[spark] class Client(
val maxMem = newAppResponse.getMaximumResourceCapability().getMemory()
logInfo("Verifying our application has not requested more than the maximum " +
s"memory capability of the cluster ($maxMem MB per container)")
- val executorMem = args.executorMemory + executorMemoryOverhead
+ val executorMem = executorMemory + executorMemoryOverhead
if (executorMem > maxMem) {
- throw new IllegalArgumentException(s"Required executor memory (${args.executorMemory}" +
+ throw new IllegalArgumentException(s"Required executor memory ($executorMemory" +
s"+$executorMemoryOverhead MB) is above the max threshold ($maxMem MB) of this cluster! " +
"Please check the values of 'yarn.scheduler.maximum-allocation-mb' and/or " +
"'yarn.nodemanager.resource.memory-mb'.")
}
- val amMem = args.amMemory + amMemoryOverhead
+ val amMem = amMemory + amMemoryOverhead
if (amMem > maxMem) {
- throw new IllegalArgumentException(s"Required AM memory (${args.amMemory}" +
+ throw new IllegalArgumentException(s"Required AM memory ($amMemory" +
s"+$amMemoryOverhead MB) is above the max threshold ($maxMem MB) of this cluster! " +
"Please increase the value of 'yarn.scheduler.maximum-allocation-mb'.")
}
@@ -334,7 +357,7 @@ private[spark] class Client(
// Upload Spark and the application JAR to the remote file system if necessary,
// and add them as local resources to the application master.
val fs = FileSystem.get(hadoopConf)
- val dst = new Path(fs.getHomeDirectory(), appStagingDir)
+ val dst = getAppStagingDirPath(sparkConf, fs, appStagingDir)
val nns = YarnSparkHadoopUtil.get.getNameNodesToAccess(sparkConf) + dst
YarnSparkHadoopUtil.get.obtainTokensForNamenodes(nns, hadoopConf, credentials)
// Used to keep track of URIs added to the distributed cache. If the same URI is added
@@ -351,14 +374,6 @@ private[spark] class Client(
val statCache: Map[URI, FileStatus] = HashMap[URI, FileStatus]()
- val oldLog4jConf = Option(System.getenv("SPARK_LOG4J_CONF"))
- if (oldLog4jConf.isDefined) {
- logWarning(
- "SPARK_LOG4J_CONF detected in the system environment. This variable has been " +
- "deprecated. Please refer to the \"Launching Spark on YARN\" documentation " +
- "for alternatives.")
- }
-
def addDistributedUri(uri: URI): Boolean = {
val uriStr = uri.toString()
if (distributedUris.contains(uriStr)) {
@@ -432,9 +447,6 @@ private[spark] class Client(
*
* Note that the archive cannot be a "local" URI. If none of the above settings are found,
* then upload all files found in $SPARK_HOME/jars.
- *
- * TODO: currently the code looks in $SPARK_HOME/lib while the work to replace assemblies
- * with a directory full of jars is ongoing.
*/
val sparkArchive = sparkConf.get(SPARK_ARCHIVE)
if (sparkArchive.isDefined) {
@@ -468,37 +480,27 @@ private[spark] class Client(
// No configuration, so fall back to uploading local jar files.
logWarning(s"Neither ${SPARK_JARS.key} nor ${SPARK_ARCHIVE.key} is set, falling back " +
"to uploading libraries under SPARK_HOME.")
- val jarsDir = new File(sparkConf.getenv("SPARK_HOME"), "lib")
- if (jarsDir.isDirectory()) {
- jarsDir.listFiles().foreach { f =>
- if (f.isFile() && f.getName().toLowerCase().endsWith(".jar")) {
- distribute(f.getAbsolutePath(), targetDir = Some(LOCALIZED_LIB_DIR))
- }
+ val jarsDir = new File(YarnCommandBuilderUtils.findJarsDir(
+ sparkConf.getenv("SPARK_HOME")))
+ jarsDir.listFiles().foreach { f =>
+ if (f.isFile() && f.getName().toLowerCase().endsWith(".jar")) {
+ distribute(f.getAbsolutePath(), targetDir = Some(LOCALIZED_LIB_DIR))
}
}
}
}
/**
- * Copy a few resources to the distributed cache if their scheme is not "local".
+ * Copy the user jar to the distributed cache if its scheme is not "local".
* Otherwise, set the corresponding key in our SparkConf to handle it downstream.
- * Each resource is represented by a 3-tuple of:
- * (1) destination resource name,
- * (2) local path to the resource,
- * (3) Spark property key to set if the scheme is not local
*/
- List(
- (APP_JAR_NAME, args.userJar, APP_JAR),
- ("log4j.properties", oldLog4jConf.orNull, null)
- ).foreach { case (destName, path, confKey) =>
- if (path != null && !path.trim().isEmpty()) {
- val (isLocal, localizedPath) = distribute(path, destName = Some(destName))
- if (isLocal && confKey != null) {
- require(localizedPath != null, s"Path $path already distributed.")
- // If the resource is intended for local use only, handle this downstream
- // by setting the appropriate property
- sparkConf.set(confKey, localizedPath)
- }
+ Option(args.userJar).filter(_.trim.nonEmpty).foreach { jar =>
+ val (isLocal, localizedPath) = distribute(jar, destName = Some(APP_JAR_NAME))
+ if (isLocal) {
+ require(localizedPath != null, s"Path $jar already distributed")
+ // If the resource is intended for local use only, handle this downstream
+ // by setting the appropriate property
+ sparkConf.set(APP_JAR, localizedPath)
}
}
@@ -511,17 +513,15 @@ private[spark] class Client(
*/
val cachedSecondaryJarLinks = ListBuffer.empty[String]
List(
- (args.addJars, LocalResourceType.FILE, true),
- (args.files, LocalResourceType.FILE, false),
- (args.archives, LocalResourceType.ARCHIVE, false)
+ (sparkConf.get(JARS_TO_DISTRIBUTE), LocalResourceType.FILE, true),
+ (sparkConf.get(FILES_TO_DISTRIBUTE), LocalResourceType.FILE, false),
+ (sparkConf.get(ARCHIVES_TO_DISTRIBUTE), LocalResourceType.ARCHIVE, false)
).foreach { case (flist, resType, addToClasspath) =>
- if (flist != null && !flist.isEmpty()) {
- flist.split(',').foreach { file =>
- val (_, localizedPath) = distribute(file, resType = resType)
- require(localizedPath != null)
- if (addToClasspath) {
- cachedSecondaryJarLinks += localizedPath
- }
+ flist.foreach { file =>
+ val (_, localizedPath) = distribute(file, resType = resType)
+ require(localizedPath != null)
+ if (addToClasspath) {
+ cachedSecondaryJarLinks += localizedPath
}
}
}
@@ -537,16 +537,15 @@ private[spark] class Client(
// The python files list needs to be treated especially. All files that are not an
// archive need to be placed in a subdirectory that will be added to PYTHONPATH.
- args.pyFiles.foreach { f =>
+ sparkConf.get(PY_FILES).foreach { f =>
val targetDir = if (f.endsWith(".py")) Some(LOCALIZED_PYTHON_DIR) else None
distribute(f, targetDir = targetDir)
}
- // Distribute an archive with Hadoop and Spark configuration for the AM.
+ // Distribute an archive with Hadoop and Spark configuration for the AM and executors.
val (_, confLocalizedPath) = distribute(createConfArchive().toURI().getPath(),
resType = LocalResourceType.ARCHIVE,
- destName = Some(LOCALIZED_CONF_DIR),
- appMasterOnly = true)
+ destName = Some(LOCALIZED_CONF_DIR))
require(confLocalizedPath != null)
localResources
@@ -555,10 +554,10 @@ private[spark] class Client(
/**
* Create an archive with the config files for distribution.
*
- * These are only used by the AM, since executors will use the configuration object broadcast by
- * the driver. The files are zipped and added to the job as an archive, so that YARN will explode
- * it when distributing to the AM. This directory is then added to the classpath of the AM
- * process, just to make sure that everybody is using the same default config.
+ * These will be used by the AM and the executors. The files are zipped and added to the job as
+ * an archive, so that YARN will explode it when distributing to the AM and the executors. This
+ * directory is then added to the classpath of the AM and executor processes, just to make sure
+ * that everybody is using the same default config.
*
* This follows the order of precedence set by the startup scripts, in which HADOOP_CONF_DIR
* shows up in the classpath before YARN_CONF_DIR.
@@ -577,11 +576,14 @@ private[spark] class Client(
// required when user changes log4j.properties directly to set the log configurations. If
// configuration file is provided through --files then executors will be taking configurations
// from --files instead of $SPARK_CONF_DIR/log4j.properties.
- val log4jFileName = "log4j.properties"
- Option(Utils.getContextOrSparkClassLoader.getResource(log4jFileName)).foreach { url =>
- if (url.getProtocol == "file") {
- hadoopConfFiles(log4jFileName) = new File(url.getPath)
- }
+
+ // Also upload metrics.properties to the distributed cache if it exists on the classpath.
+ // If the user specifies this file through --files, the executors will use the one from
+ // --files instead.
+ for { prop <- Seq("log4j.properties", "metrics.properties")
+ url <- Option(Utils.getContextOrSparkClassLoader.getResource(prop))
+ if url.getProtocol == "file" } {
+ hadoopConfFiles(prop) = new File(url.getPath)
}
Seq("HADOOP_CONF_DIR", "YARN_CONF_DIR").foreach { envKey =>
@@ -660,13 +662,13 @@ private[spark] class Client(
pySparkArchives: Seq[String]): HashMap[String, String] = {
logInfo("Setting up the launch environment for our AM container")
val env = new HashMap[String, String]()
- populateClasspath(args, yarnConf, sparkConf, env, true, sparkConf.get(DRIVER_CLASS_PATH))
+ populateClasspath(args, yarnConf, sparkConf, env, sparkConf.get(DRIVER_CLASS_PATH))
env("SPARK_YARN_MODE") = "true"
env("SPARK_YARN_STAGING_DIR") = stagingDir
env("SPARK_USER") = UserGroupInformation.getCurrentUser().getShortUserName()
if (loginFromKeytab) {
val remoteFs = FileSystem.get(hadoopConf)
- val stagingDirPath = new Path(remoteFs.getHomeDirectory, stagingDir)
+ val stagingDirPath = getAppStagingDirPath(sparkConf, remoteFs, stagingDir)
val credentialsFile = "credentials-" + UUID.randomUUID().toString
sparkConf.set(CREDENTIALS_FILE_PATH, new Path(stagingDirPath, credentialsFile).toString)
logInfo(s"Credentials file set to: $credentialsFile")
@@ -694,7 +696,7 @@ private[spark] class Client(
//
// NOTE: the code currently does not handle .py files defined with a "local:" scheme.
val pythonPath = new ListBuffer[String]()
- val (pyFiles, pyArchives) = args.pyFiles.partition(_.endsWith(".py"))
+ val (pyFiles, pyArchives) = sparkConf.get(PY_FILES).partition(_.endsWith(".py"))
if (pyFiles.nonEmpty) {
pythonPath += buildPath(YarnSparkHadoopUtil.expandEnvironment(Environment.PWD),
LOCALIZED_PYTHON_DIR)
@@ -791,7 +793,7 @@ private[spark] class Client(
var prefixEnv: Option[String] = None
// Add Xmx for AM memory
- javaOpts += "-Xmx" + args.amMemory + "m"
+ javaOpts += "-Xmx" + amMemory + "m"
val tmpDir = new Path(
YarnSparkHadoopUtil.expandEnvironment(Environment.PWD),
@@ -837,16 +839,16 @@ private[spark] class Client(
// Validate and include yarn am specific java options in yarn-client mode.
sparkConf.get(AM_JAVA_OPTIONS).foreach { opts =>
if (opts.contains("-Dspark")) {
- val msg = s"$${amJavaOptions.key} is not allowed to set Spark options (was '$opts'). "
+ val msg = s"${AM_JAVA_OPTIONS.key} is not allowed to set Spark options (was '$opts')."
throw new SparkException(msg)
}
- if (opts.contains("-Xmx") || opts.contains("-Xms")) {
- val msg = s"$${amJavaOptions.key} is not allowed to alter memory settings (was '$opts')."
+ if (opts.contains("-Xmx")) {
+ val msg = s"${AM_JAVA_OPTIONS.key} is not allowed to specify max heap memory settings " +
+ s"(was '$opts'). Use spark.yarn.am.memory instead."
throw new SparkException(msg)
}
javaOpts ++= Utils.splitCommandString(opts).map(YarnSparkHadoopUtil.escapeForShell)
}
-
sparkConf.get(AM_LIBRARY_PATH).foreach { paths =>
prefixEnv = Some(getClusterPath(sparkConf, Utils.libraryPathEnvPrefix(Seq(paths))))
}
@@ -895,8 +897,6 @@ private[spark] class Client(
val amArgs =
Seq(amClass) ++ userClass ++ userJar ++ primaryPyFile ++ primaryRFile ++
userArgs ++ Seq(
- "--executor-memory", args.executorMemory.toString + "m",
- "--executor-cores", args.executorCores.toString,
"--properties-file", buildPath(YarnSparkHadoopUtil.expandEnvironment(Environment.PWD),
LOCALIZED_CONF_DIR, SPARK_CONF_FILE))
@@ -935,10 +935,10 @@ private[spark] class Client(
}
def setupCredentials(): Unit = {
- loginFromKeytab = args.principal != null || sparkConf.contains(PRINCIPAL.key)
+ loginFromKeytab = sparkConf.contains(PRINCIPAL.key)
if (loginFromKeytab) {
- principal = Option(args.principal).orElse(sparkConf.get(PRINCIPAL)).get
- keytab = Option(args.keytab).orElse(sparkConf.get(KEYTAB)).orNull
+ principal = sparkConf.get(PRINCIPAL).get
+ keytab = sparkConf.get(KEYTAB).orNull
require(keytab != null, "Keytab must be specified when principal is specified.")
logInfo("Attempting to login to the Kerberos" +
@@ -1100,7 +1100,7 @@ private[spark] class Client(
}
-object Client extends Logging {
+private object Client extends Logging {
def main(argStrings: Array[String]) {
if (!sys.props.contains("SPARK_SUBMIT")) {
@@ -1113,11 +1113,7 @@ object Client extends Logging {
System.setProperty("SPARK_YARN_MODE", "true")
val sparkConf = new SparkConf
- val args = new ClientArguments(argStrings, sparkConf)
- // to maintain backwards-compatibility
- if (!Utils.isDynamicAllocationEnabled(sparkConf)) {
- sparkConf.setIfMissing(EXECUTOR_INSTANCES, args.numExecutors)
- }
+ val args = new ClientArguments(argStrings)
new Client(args, sparkConf).run()
}
@@ -1237,18 +1233,16 @@ object Client extends Logging {
conf: Configuration,
sparkConf: SparkConf,
env: HashMap[String, String],
- isAM: Boolean,
extraClassPath: Option[String] = None): Unit = {
extraClassPath.foreach { cp =>
addClasspathEntry(getClusterPath(sparkConf, cp), env)
}
+
addClasspathEntry(YarnSparkHadoopUtil.expandEnvironment(Environment.PWD), env)
- if (isAM) {
- addClasspathEntry(
- YarnSparkHadoopUtil.expandEnvironment(Environment.PWD) + Path.SEPARATOR +
- LOCALIZED_CONF_DIR, env)
- }
+ addClasspathEntry(
+ YarnSparkHadoopUtil.expandEnvironment(Environment.PWD) + Path.SEPARATOR +
+ LOCALIZED_CONF_DIR, env)
if (sparkConf.get(USER_CLASS_PATH_FIRST)) {
// in order to properly add the app jar when user classpath is first
@@ -1264,7 +1258,7 @@ object Client extends Logging {
val secondaryJars =
if (args != null) {
- getSecondaryJarUris(Option(args.addJars).map(_.split(",").toSeq))
+ getSecondaryJarUris(Option(sparkConf.get(JARS_TO_DISTRIBUTE)))
} else {
getSecondaryJarUris(sparkConf.get(SECONDARY_JARS))
}
@@ -1444,4 +1438,16 @@ object Client extends Logging {
uri.startsWith(s"$LOCAL_SCHEME:")
}
+ /**
+ * Returns the app staging dir, based on the STAGING_DIR configuration if it is set,
+ * otherwise on the user's home directory.
+ */
+ private def getAppStagingDirPath(
+ conf: SparkConf,
+ fs: FileSystem,
+ appStagingDir: String): Path = {
+ val baseDir = conf.get(STAGING_DIR).map { new Path(_) }.getOrElse(fs.getHomeDirectory())
+ new Path(baseDir, appStagingDir)
+ }
+
}
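
The two overhead values above follow the usual YARN sizing rule: a fixed fraction of the container memory, but never less than a floor value. Below is a minimal standalone sketch of that calculation, assuming the conventional constants of 0.10 for MEMORY_OVERHEAD_FACTOR and 384 MB for MEMORY_OVERHEAD_MIN (their definitions live in YarnSparkHadoopUtil and are not part of this hunk).

// Sketch only: mirrors the formula used for amMemoryOverhead and executorMemoryOverhead
// in Client above. The two constants are assumed values; check YarnSparkHadoopUtil for
// the ones in your Spark version.
object MemoryOverheadSketch {
  val MEMORY_OVERHEAD_FACTOR = 0.10   // assumed fraction of container memory
  val MEMORY_OVERHEAD_MIN = 384L      // assumed floor, in MB

  def overhead(containerMemoryMb: Int, explicitOverheadMb: Option[Long]): Int = {
    explicitOverheadMb.getOrElse(
      math.max((MEMORY_OVERHEAD_FACTOR * containerMemoryMb).toLong, MEMORY_OVERHEAD_MIN)).toInt
  }

  def main(args: Array[String]): Unit = {
    println(overhead(4096, None))        // 409: 10% of 4096 MB, above the 384 MB floor
    println(overhead(1024, None))        // 384: 10% would be 102 MB, so the floor wins
    println(overhead(4096, Some(512L)))  // 512: an explicit *.memoryOverhead setting wins
  }
}

Under those assumed constants, a 4 GB executor requests a 4505 MB container (4096 + 409), and it is that sum which the "Required executor memory ... is above the max threshold" check earlier in this patch compares against the cluster's maximum container size.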
diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala
index 47b4cc3009..61c027ec44 100644
--- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala
+++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala
@@ -19,118 +19,20 @@ package org.apache.spark.deploy.yarn
import scala.collection.mutable.ArrayBuffer
-import org.apache.spark.{SparkConf, SparkException}
-import org.apache.spark.deploy.yarn.YarnSparkHadoopUtil._
-import org.apache.spark.deploy.yarn.config._
-import org.apache.spark.internal.config._
-import org.apache.spark.util.{IntParam, MemoryParam, Utils}
-
// TODO: Add code and support for ensuring that yarn resource 'tasks' are location aware !
-private[spark] class ClientArguments(
- args: Array[String],
- sparkConf: SparkConf) {
+private[spark] class ClientArguments(args: Array[String]) {
- var addJars: String = null
- var files: String = null
- var archives: String = null
var userJar: String = null
var userClass: String = null
- var pyFiles: Seq[String] = Nil
var primaryPyFile: String = null
var primaryRFile: String = null
var userArgs: ArrayBuffer[String] = new ArrayBuffer[String]()
- var executorMemory = 1024 // MB
- var executorCores = 1
- var numExecutors = DEFAULT_NUMBER_EXECUTORS
- var amQueue = sparkConf.get(QUEUE_NAME)
- var amMemory: Int = _
- var amCores: Int = _
- var appName: String = "Spark"
- var priority = 0
- var principal: String = null
- var keytab: String = null
- def isClusterMode: Boolean = userClass != null
-
- private var driverMemory: Int = Utils.DEFAULT_DRIVER_MEM_MB // MB
- private var driverCores: Int = 1
- private val isDynamicAllocationEnabled = Utils.isDynamicAllocationEnabled(sparkConf)
parseArgs(args.toList)
- loadEnvironmentArgs()
- validateArgs()
-
- // Additional memory to allocate to containers
- val amMemoryOverheadEntry = if (isClusterMode) DRIVER_MEMORY_OVERHEAD else AM_MEMORY_OVERHEAD
- val amMemoryOverhead = sparkConf.get(amMemoryOverheadEntry).getOrElse(
- math.max((MEMORY_OVERHEAD_FACTOR * amMemory).toLong, MEMORY_OVERHEAD_MIN)).toInt
-
- val executorMemoryOverhead = sparkConf.get(EXECUTOR_MEMORY_OVERHEAD).getOrElse(
- math.max((MEMORY_OVERHEAD_FACTOR * executorMemory).toLong, MEMORY_OVERHEAD_MIN)).toInt
-
- /** Load any default arguments provided through environment variables and Spark properties. */
- private def loadEnvironmentArgs(): Unit = {
- // For backward compatibility, SPARK_YARN_DIST_{ARCHIVES/FILES} should be resolved to hdfs://,
- // while spark.yarn.dist.{archives/files} should be resolved to file:// (SPARK-2051).
- files = Option(files)
- .orElse(sparkConf.get(FILES_TO_DISTRIBUTE).map(p => Utils.resolveURIs(p)))
- .orElse(sys.env.get("SPARK_YARN_DIST_FILES"))
- .orNull
- archives = Option(archives)
- .orElse(sparkConf.get(ARCHIVES_TO_DISTRIBUTE).map(p => Utils.resolveURIs(p)))
- .orElse(sys.env.get("SPARK_YARN_DIST_ARCHIVES"))
- .orNull
- // If dynamic allocation is enabled, start at the configured initial number of executors.
- // Default to minExecutors if no initialExecutors is set.
- numExecutors = YarnSparkHadoopUtil.getInitialTargetExecutorNumber(sparkConf, numExecutors)
- principal = Option(principal)
- .orElse(sparkConf.get(PRINCIPAL))
- .orNull
- keytab = Option(keytab)
- .orElse(sparkConf.get(KEYTAB))
- .orNull
- }
-
- /**
- * Fail fast if any arguments provided are invalid.
- * This is intended to be called only after the provided arguments have been parsed.
- */
- private def validateArgs(): Unit = {
- if (numExecutors < 0 || (!isDynamicAllocationEnabled && numExecutors == 0)) {
- throw new IllegalArgumentException(
- s"""
- |Number of executors was $numExecutors, but must be at least 1
- |(or 0 if dynamic executor allocation is enabled).
- |${getUsageMessage()}
- """.stripMargin)
- }
- if (executorCores < sparkConf.get(CPUS_PER_TASK)) {
- throw new SparkException(s"Executor cores must not be less than ${CPUS_PER_TASK.key}.")
- }
- // scalastyle:off println
- if (isClusterMode) {
- for (key <- Seq(AM_MEMORY.key, AM_MEMORY_OVERHEAD.key, AM_CORES.key)) {
- if (sparkConf.contains(key)) {
- println(s"$key is set but does not apply in cluster mode.")
- }
- }
- amMemory = driverMemory
- amCores = driverCores
- } else {
- for (key <- Seq(DRIVER_MEMORY_OVERHEAD.key, DRIVER_CORES.key)) {
- if (sparkConf.contains(key)) {
- println(s"$key is set but does not apply in client mode.")
- }
- }
- amMemory = sparkConf.get(AM_MEMORY).toInt
- amCores = sparkConf.get(AM_CORES)
- }
- // scalastyle:on println
- }
private def parseArgs(inputArgs: List[String]): Unit = {
var args = inputArgs
- // scalastyle:off println
while (!args.isEmpty) {
args match {
case ("--jar") :: value :: tail =>
@@ -149,88 +51,16 @@ private[spark] class ClientArguments(
primaryRFile = value
args = tail
- case ("--args" | "--arg") :: value :: tail =>
- if (args(0) == "--args") {
- println("--args is deprecated. Use --arg instead.")
- }
+ case ("--arg") :: value :: tail =>
userArgs += value
args = tail
- case ("--master-class" | "--am-class") :: value :: tail =>
- println(s"${args(0)} is deprecated and is not used anymore.")
- args = tail
-
- case ("--master-memory" | "--driver-memory") :: MemoryParam(value) :: tail =>
- if (args(0) == "--master-memory") {
- println("--master-memory is deprecated. Use --driver-memory instead.")
- }
- driverMemory = value
- args = tail
-
- case ("--driver-cores") :: IntParam(value) :: tail =>
- driverCores = value
- args = tail
-
- case ("--num-workers" | "--num-executors") :: IntParam(value) :: tail =>
- if (args(0) == "--num-workers") {
- println("--num-workers is deprecated. Use --num-executors instead.")
- }
- numExecutors = value
- args = tail
-
- case ("--worker-memory" | "--executor-memory") :: MemoryParam(value) :: tail =>
- if (args(0) == "--worker-memory") {
- println("--worker-memory is deprecated. Use --executor-memory instead.")
- }
- executorMemory = value
- args = tail
-
- case ("--worker-cores" | "--executor-cores") :: IntParam(value) :: tail =>
- if (args(0) == "--worker-cores") {
- println("--worker-cores is deprecated. Use --executor-cores instead.")
- }
- executorCores = value
- args = tail
-
- case ("--queue") :: value :: tail =>
- amQueue = value
- args = tail
-
- case ("--name") :: value :: tail =>
- appName = value
- args = tail
-
- case ("--addJars") :: value :: tail =>
- addJars = value
- args = tail
-
- case ("--py-files") :: value :: tail =>
- pyFiles = value.split(",")
- args = tail
-
- case ("--files") :: value :: tail =>
- files = value
- args = tail
-
- case ("--archives") :: value :: tail =>
- archives = value
- args = tail
-
- case ("--principal") :: value :: tail =>
- principal = value
- args = tail
-
- case ("--keytab") :: value :: tail =>
- keytab = value
- args = tail
-
case Nil =>
case _ =>
throw new IllegalArgumentException(getUsageMessage(args))
}
}
- // scalastyle:on println
if (primaryPyFile != null && primaryRFile != null) {
throw new IllegalArgumentException("Cannot have primary-py-file and primary-r-file" +
@@ -240,7 +70,6 @@ private[spark] class ClientArguments(
private def getUsageMessage(unknownParam: List[String] = null): String = {
val message = if (unknownParam != null) s"Unknown/unsupported param $unknownParam\n" else ""
- val mem_mb = Utils.DEFAULT_DRIVER_MEM_MB
message +
s"""
|Usage: org.apache.spark.deploy.yarn.Client [options]
@@ -252,20 +81,6 @@ private[spark] class ClientArguments(
| --primary-r-file A main R file
| --arg ARG Argument to be passed to your application's main class.
| Multiple invocations are possible, each will be passed in order.
- | --num-executors NUM Number of executors to start (Default: 2)
- | --executor-cores NUM Number of cores per executor (Default: 1).
- | --driver-memory MEM Memory for driver (e.g. 1000M, 2G) (Default: $mem_mb Mb)
- | --driver-cores NUM Number of cores used by the driver (Default: 1).
- | --executor-memory MEM Memory per executor (e.g. 1000M, 2G) (Default: 1G)
- | --name NAME The name of your application (Default: Spark)
- | --queue QUEUE The hadoop queue to use for allocation requests (Default:
- | 'default')
- | --addJars jars Comma separated list of local jars that want SparkContext.addJar
- | to work with.
- | --py-files PY_FILES Comma-separated list of .zip, .egg, or .py files to
- | place on the PYTHONPATH for Python apps.
- | --files files Comma separated list of files to be distributed with the job.
- | --archives archives Comma separated list of archives to be distributed with the job.
""".stripMargin
}
}
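
With most of these flags removed, the equivalent settings now travel as ordinary Spark configuration rather than Client command-line arguments. The following is a rough sketch of the correspondence, using property names that appear elsewhere in this patch; the list is illustrative, not exhaustive.

import org.apache.spark.SparkConf

// Sketch only: properties standing in for the removed ClientArguments flags.
// These would normally be supplied via spark-submit --conf or spark-defaults.conf.
object ClientConfSketch {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf(false)
      .set("spark.app.name", "foo-test-app")                 // was --name
      .set("spark.yarn.queue", "staging-queue")               // was --queue
      .set("spark.executor.memory", "2g")                     // was --executor-memory
      .set("spark.executor.cores", "2")                       // was --executor-cores
      .set("spark.executor.instances", "4")                   // was --num-executors
      .set("spark.yarn.dist.jars", "extra1.jar,extra2.jar")   // was --addJars
      .set("spark.yarn.dist.files", "settings.conf")          // was --files
      .set("spark.yarn.dist.archives", "data.zip")            // was --archives
      .set("spark.submit.pyFiles", "deps.zip")                // was --py-files
    conf.getAll.sortBy(_._1).foreach { case (k, v) => println(s"$k=$v") }
  }
}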
diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala
index f956a4d1d5..ef7908a3ef 100644
--- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala
+++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala
@@ -147,7 +147,6 @@ private[yarn] class ExecutorRunnable(
// Set the JVM memory
val executorMemoryString = executorMemory + "m"
- javaOpts += "-Xms" + executorMemoryString
javaOpts += "-Xmx" + executorMemoryString
// Set extra Java options for the executor, if defined
@@ -289,8 +288,7 @@ private[yarn] class ExecutorRunnable(
private def prepareEnvironment(container: Container): HashMap[String, String] = {
val env = new HashMap[String, String]()
- Client.populateClasspath(null, yarnConf, sparkConf, env, false,
- sparkConf.get(EXECUTOR_CLASS_PATH))
+ Client.populateClasspath(null, yarnConf, sparkConf, env, sparkConf.get(EXECUTOR_CLASS_PATH))
sparkConf.getExecutorEnv.foreach { case (key, value) =>
// This assumes each executor environment variable set here is a path
diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala
index d094302362..23742eab62 100644
--- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala
+++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala
@@ -36,9 +36,11 @@ import org.apache.spark.{SecurityManager, SparkConf, SparkException}
import org.apache.spark.deploy.yarn.YarnSparkHadoopUtil._
import org.apache.spark.deploy.yarn.config._
import org.apache.spark.internal.Logging
+import org.apache.spark.internal.config._
import org.apache.spark.rpc.{RpcCallContext, RpcEndpointRef}
import org.apache.spark.scheduler.{ExecutorExited, ExecutorLossReason}
import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages.RemoveExecutor
+import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages.RetrieveLastAllocatedExecutorId
import org.apache.spark.util.ThreadUtils
/**
@@ -61,7 +63,6 @@ private[yarn] class YarnAllocator(
sparkConf: SparkConf,
amClient: AMRMClient[ContainerRequest],
appAttemptId: ApplicationAttemptId,
- args: ApplicationMasterArguments,
securityMgr: SecurityManager)
extends Logging {
@@ -83,8 +84,23 @@ private[yarn] class YarnAllocator(
new ConcurrentHashMap[ContainerId, java.lang.Boolean])
@volatile private var numExecutorsRunning = 0
- // Used to generate a unique ID per executor
- private var executorIdCounter = 0
+
+ /**
+ * Used to generate a unique ID per executor
+ *
+ * When the AM restarts, `executorIdCounter` would otherwise reset to 0, so the IDs of new
+ * executors would start from 1 again and collide with executors created before the restart.
+ * To avoid this, initialize `executorIdCounter` with the maximum executor ID obtained from
+ * the driver.
+ *
+ * This collision can only happen in yarn-client mode, where the driver keeps running across
+ * an AM restart. See the JIRA issue referenced below for details.
+ *
+ * @see SPARK-12864
+ */
+ private var executorIdCounter: Int =
+ driverRef.askWithRetry[Int](RetrieveLastAllocatedExecutorId)
+
@volatile private var numExecutorsFailed = 0
@volatile private var targetNumExecutors =
@@ -107,12 +123,12 @@ private[yarn] class YarnAllocator(
private val containerIdToExecutorId = new HashMap[ContainerId, String]
// Executor memory in MB.
- protected val executorMemory = args.executorMemory
+ protected val executorMemory = sparkConf.get(EXECUTOR_MEMORY).toInt
// Additional memory overhead.
protected val memoryOverhead: Int = sparkConf.get(EXECUTOR_MEMORY_OVERHEAD).getOrElse(
math.max((MEMORY_OVERHEAD_FACTOR * executorMemory).toInt, MEMORY_OVERHEAD_MIN)).toInt
// Number of cores per executor.
- protected val executorCores = args.executorCores
+ protected val executorCores = sparkConf.get(EXECUTOR_CORES)
// Resource capability requested for each executors
private[yarn] val resource = Resource.newInstance(executorMemory + memoryOverhead, executorCores)
@@ -132,11 +148,10 @@ private[yarn] class YarnAllocator(
classOf[Array[String]], classOf[Array[String]], classOf[Priority], classOf[Boolean],
classOf[String]))
} catch {
- case e: NoSuchMethodException => {
+ case e: NoSuchMethodException =>
logWarning(s"Node label expression $expr will be ignored because YARN version on" +
" classpath does not support it.")
None
- }
}
}
diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala
index 83d30b7352..e7f7544664 100644
--- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala
+++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala
@@ -39,7 +39,7 @@ import org.apache.spark.util.Utils
/**
* Handles registering and unregistering the application with the YARN ResourceManager.
*/
-private[spark] class YarnRMClient(args: ApplicationMasterArguments) extends Logging {
+private[spark] class YarnRMClient extends Logging {
private var amClient: AMRMClient[ContainerRequest] = _
private var uiHistoryAddress: String = _
@@ -72,8 +72,7 @@ private[spark] class YarnRMClient(args: ApplicationMasterArguments) extends Logg
amClient.registerApplicationMaster(Utils.localHostName(), 0, uiAddress)
registered = true
}
- new YarnAllocator(driverUrl, driverRef, conf, sparkConf, amClient, getAttemptId(), args,
- securityMgr)
+ new YarnAllocator(driverUrl, driverRef, conf, sparkConf, amClient, getAttemptId(), securityMgr)
}
/**
diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala
index 2915e664be..4b36da309d 100644
--- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala
+++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala
@@ -135,8 +135,8 @@ class YarnSparkHadoopUtil extends SparkHadoopUtil {
}
/**
- * Obtains token for the Hive metastore and adds them to the credentials.
- */
+ * Obtains tokens for the Hive metastore and adds them to the credentials.
+ */
def obtainTokenForHiveMetastore(
sparkConf: SparkConf,
conf: Configuration,
@@ -149,8 +149,8 @@ class YarnSparkHadoopUtil extends SparkHadoopUtil {
}
/**
- * Obtain a security token for HBase.
- */
+ * Obtain a security token for HBase.
+ */
def obtainTokenForHBase(
sparkConf: SparkConf,
conf: Configuration,
@@ -164,10 +164,10 @@ class YarnSparkHadoopUtil extends SparkHadoopUtil {
}
/**
- * Return whether delegation tokens should be retrieved for the given service when security is
- * enabled. By default, tokens are retrieved, but that behavior can be changed by setting
- * a service-specific configuration.
- */
+ * Return whether delegation tokens should be retrieved for the given service when security is
+ * enabled. By default, tokens are retrieved, but that behavior can be changed by setting
+ * a service-specific configuration.
+ */
private def shouldGetTokens(conf: SparkConf, service: String): Boolean = {
conf.getBoolean(s"spark.yarn.security.tokens.${service}.enabled", true)
}
@@ -512,7 +512,7 @@ object YarnSparkHadoopUtil {
val initialNumExecutors = conf.get(DYN_ALLOCATION_INITIAL_EXECUTORS)
val maxNumExecutors = conf.get(DYN_ALLOCATION_MAX_EXECUTORS)
require(initialNumExecutors >= minNumExecutors && initialNumExecutors <= maxNumExecutors,
- s"initial executor number $initialNumExecutors must between min executor number" +
+ s"initial executor number $initialNumExecutors must between min executor number " +
s"$minNumExecutors and max executor number $maxNumExecutors")
initialNumExecutors
diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/config.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/config.scala
index 0789567ae6..edfbfc5d58 100644
--- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/config.scala
+++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/config.scala
@@ -31,75 +31,87 @@ package object config {
"in YARN Application Reports, which can be used for filtering when querying YARN.")
.stringConf
.toSequence
- .optional
+ .createOptional
private[spark] val ATTEMPT_FAILURE_VALIDITY_INTERVAL_MS =
ConfigBuilder("spark.yarn.am.attemptFailuresValidityInterval")
.doc("Interval after which AM failures will be considered independent and " +
"not accumulate towards the attempt count.")
.timeConf(TimeUnit.MILLISECONDS)
- .optional
+ .createOptional
private[spark] val MAX_APP_ATTEMPTS = ConfigBuilder("spark.yarn.maxAppAttempts")
.doc("Maximum number of AM attempts before failing the app.")
.intConf
- .optional
+ .createOptional
private[spark] val USER_CLASS_PATH_FIRST = ConfigBuilder("spark.yarn.user.classpath.first")
.doc("Whether to place user jars in front of Spark's classpath.")
.booleanConf
- .withDefault(false)
+ .createWithDefault(false)
private[spark] val GATEWAY_ROOT_PATH = ConfigBuilder("spark.yarn.config.gatewayPath")
.doc("Root of configuration paths that is present on gateway nodes, and will be replaced " +
"with the corresponding path in cluster machines.")
.stringConf
- .withDefault(null)
+ .createWithDefault(null)
private[spark] val REPLACEMENT_ROOT_PATH = ConfigBuilder("spark.yarn.config.replacementPath")
.doc(s"Path to use as a replacement for ${GATEWAY_ROOT_PATH.key} when launching processes " +
"in the YARN cluster.")
.stringConf
- .withDefault(null)
+ .createWithDefault(null)
private[spark] val QUEUE_NAME = ConfigBuilder("spark.yarn.queue")
.stringConf
- .withDefault("default")
+ .createWithDefault("default")
private[spark] val HISTORY_SERVER_ADDRESS = ConfigBuilder("spark.yarn.historyServer.address")
.stringConf
- .optional
+ .createOptional
/* File distribution. */
private[spark] val SPARK_ARCHIVE = ConfigBuilder("spark.yarn.archive")
.doc("Location of archive containing jars files with Spark classes.")
.stringConf
- .optional
+ .createOptional
private[spark] val SPARK_JARS = ConfigBuilder("spark.yarn.jars")
.doc("Location of jars containing Spark classes.")
.stringConf
.toSequence
- .optional
+ .createOptional
private[spark] val ARCHIVES_TO_DISTRIBUTE = ConfigBuilder("spark.yarn.dist.archives")
.stringConf
- .optional
+ .toSequence
+ .createWithDefault(Nil)
private[spark] val FILES_TO_DISTRIBUTE = ConfigBuilder("spark.yarn.dist.files")
.stringConf
- .optional
+ .toSequence
+ .createWithDefault(Nil)
+
+ private[spark] val JARS_TO_DISTRIBUTE = ConfigBuilder("spark.yarn.dist.jars")
+ .stringConf
+ .toSequence
+ .createWithDefault(Nil)
private[spark] val PRESERVE_STAGING_FILES = ConfigBuilder("spark.yarn.preserve.staging.files")
.doc("Whether to preserve temporary files created by the job in HDFS.")
.booleanConf
- .withDefault(false)
+ .createWithDefault(false)
private[spark] val STAGING_FILE_REPLICATION = ConfigBuilder("spark.yarn.submit.file.replication")
.doc("Replication factor for files uploaded by Spark to HDFS.")
.intConf
- .optional
+ .createOptional
+
+ private[spark] val STAGING_DIR = ConfigBuilder("spark.yarn.stagingDir")
+ .doc("Staging directory used while submitting applications.")
+ .stringConf
+ .createOptional
/* Cluster-mode launcher configuration. */
@@ -107,143 +119,146 @@ package object config {
.doc("In cluster mode, whether to wait for the application to finish before exiting the " +
"launcher process.")
.booleanConf
- .withDefault(true)
+ .createWithDefault(true)
private[spark] val REPORT_INTERVAL = ConfigBuilder("spark.yarn.report.interval")
.doc("Interval between reports of the current app status in cluster mode.")
.timeConf(TimeUnit.MILLISECONDS)
- .withDefaultString("1s")
+ .createWithDefaultString("1s")
/* Shared Client-mode AM / Driver configuration. */
private[spark] val AM_MAX_WAIT_TIME = ConfigBuilder("spark.yarn.am.waitTime")
.timeConf(TimeUnit.MILLISECONDS)
- .withDefaultString("100s")
+ .createWithDefaultString("100s")
private[spark] val AM_NODE_LABEL_EXPRESSION = ConfigBuilder("spark.yarn.am.nodeLabelExpression")
.doc("Node label expression for the AM.")
.stringConf
- .optional
+ .createOptional
private[spark] val CONTAINER_LAUNCH_MAX_THREADS =
ConfigBuilder("spark.yarn.containerLauncherMaxThreads")
.intConf
- .withDefault(25)
+ .createWithDefault(25)
private[spark] val MAX_EXECUTOR_FAILURES = ConfigBuilder("spark.yarn.max.executor.failures")
.intConf
- .optional
+ .createOptional
private[spark] val MAX_REPORTER_THREAD_FAILURES =
ConfigBuilder("spark.yarn.scheduler.reporterThread.maxFailures")
.intConf
- .withDefault(5)
+ .createWithDefault(5)
private[spark] val RM_HEARTBEAT_INTERVAL =
ConfigBuilder("spark.yarn.scheduler.heartbeat.interval-ms")
.timeConf(TimeUnit.MILLISECONDS)
- .withDefaultString("3s")
+ .createWithDefaultString("3s")
private[spark] val INITIAL_HEARTBEAT_INTERVAL =
ConfigBuilder("spark.yarn.scheduler.initial-allocation.interval")
.timeConf(TimeUnit.MILLISECONDS)
- .withDefaultString("200ms")
+ .createWithDefaultString("200ms")
private[spark] val SCHEDULER_SERVICES = ConfigBuilder("spark.yarn.services")
.doc("A comma-separated list of class names of services to add to the scheduler.")
.stringConf
.toSequence
- .withDefault(Nil)
+ .createWithDefault(Nil)
/* Client-mode AM configuration. */
private[spark] val AM_CORES = ConfigBuilder("spark.yarn.am.cores")
.intConf
- .withDefault(1)
+ .createWithDefault(1)
private[spark] val AM_JAVA_OPTIONS = ConfigBuilder("spark.yarn.am.extraJavaOptions")
.doc("Extra Java options for the client-mode AM.")
.stringConf
- .optional
+ .createOptional
private[spark] val AM_LIBRARY_PATH = ConfigBuilder("spark.yarn.am.extraLibraryPath")
.doc("Extra native library path for the client-mode AM.")
.stringConf
- .optional
+ .createOptional
private[spark] val AM_MEMORY_OVERHEAD = ConfigBuilder("spark.yarn.am.memoryOverhead")
.bytesConf(ByteUnit.MiB)
- .optional
+ .createOptional
private[spark] val AM_MEMORY = ConfigBuilder("spark.yarn.am.memory")
.bytesConf(ByteUnit.MiB)
- .withDefaultString("512m")
+ .createWithDefaultString("512m")
/* Driver configuration. */
private[spark] val DRIVER_CORES = ConfigBuilder("spark.driver.cores")
.intConf
- .optional
+ .createWithDefault(1)
private[spark] val DRIVER_MEMORY_OVERHEAD = ConfigBuilder("spark.yarn.driver.memoryOverhead")
.bytesConf(ByteUnit.MiB)
- .optional
+ .createOptional
/* Executor configuration. */
+ private[spark] val EXECUTOR_CORES = ConfigBuilder("spark.executor.cores")
+ .intConf
+ .createWithDefault(1)
+
private[spark] val EXECUTOR_MEMORY_OVERHEAD = ConfigBuilder("spark.yarn.executor.memoryOverhead")
.bytesConf(ByteUnit.MiB)
- .optional
+ .createOptional
private[spark] val EXECUTOR_NODE_LABEL_EXPRESSION =
ConfigBuilder("spark.yarn.executor.nodeLabelExpression")
.doc("Node label expression for executors.")
.stringConf
- .optional
+ .createOptional
/* Security configuration. */
private[spark] val CREDENTIAL_FILE_MAX_COUNT =
ConfigBuilder("spark.yarn.credentials.file.retention.count")
.intConf
- .withDefault(5)
+ .createWithDefault(5)
private[spark] val CREDENTIALS_FILE_MAX_RETENTION =
ConfigBuilder("spark.yarn.credentials.file.retention.days")
.intConf
- .withDefault(5)
+ .createWithDefault(5)
private[spark] val NAMENODES_TO_ACCESS = ConfigBuilder("spark.yarn.access.namenodes")
.doc("Extra NameNode URLs for which to request delegation tokens. The NameNode that hosts " +
"fs.defaultFS does not need to be listed here.")
.stringConf
.toSequence
- .withDefault(Nil)
+ .createWithDefault(Nil)
private[spark] val TOKEN_RENEWAL_INTERVAL = ConfigBuilder("spark.yarn.token.renewal.interval")
- .internal
+ .internal()
.timeConf(TimeUnit.MILLISECONDS)
- .optional
+ .createOptional
/* Private configs. */
private[spark] val CREDENTIALS_FILE_PATH = ConfigBuilder("spark.yarn.credentials.file")
- .internal
+ .internal()
.stringConf
- .withDefault(null)
+ .createWithDefault(null)
// Internal config to propagate the location of the user's jar to the driver/executors
private[spark] val APP_JAR = ConfigBuilder("spark.yarn.user.jar")
- .internal
+ .internal()
.stringConf
- .optional
+ .createOptional
// Internal config to propagate the locations of any extra jars to add to the classpath
// of the executors
private[spark] val SECONDARY_JARS = ConfigBuilder("spark.yarn.secondary.jars")
- .internal
+ .internal()
.stringConf
.toSequence
- .optional
-
+ .createOptional
}
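
Most of this hunk is a mechanical rename of the ConfigBuilder terminators: optional, withDefault, withDefaultString and internal become createOptional, createWithDefault, createWithDefaultString and internal(). As a minimal sketch of defining and reading such an entry with the renamed methods: the entry names below are hypothetical, and the sketch assumes access to the private[spark] ConfigBuilder API, i.e. it would live under the org.apache.spark package tree.

package org.apache.spark.deploy.yarn

import java.util.concurrent.TimeUnit

import org.apache.spark.SparkConf
import org.apache.spark.internal.config.ConfigBuilder

// Sketch only: SKETCH_INTERVAL and SKETCH_LABEL are illustrative entries,
// not configs defined by Spark.
private[yarn] object ConfigEntrySketch {
  val SKETCH_INTERVAL = ConfigBuilder("spark.yarn.sketch.interval")
    .doc("Illustrative time entry using the renamed terminator.")
    .timeConf(TimeUnit.MILLISECONDS)
    .createWithDefaultString("3s")

  val SKETCH_LABEL = ConfigBuilder("spark.yarn.sketch.label")
    .stringConf
    .createOptional

  def main(args: Array[String]): Unit = {
    val conf = new SparkConf(false)
    println(conf.get(SKETCH_INTERVAL))  // 3000, i.e. the default converted to milliseconds
    println(conf.get(SKETCH_LABEL))     // None, since no value was set
    println(conf.set(SKETCH_LABEL.key, "gpu").get(SKETCH_LABEL))  // Some(gpu)
  }
}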
diff --git a/yarn/src/main/scala/org/apache/spark/launcher/YarnCommandBuilderUtils.scala b/yarn/src/main/scala/org/apache/spark/launcher/YarnCommandBuilderUtils.scala
index 7d246bf407..6c3556a2ee 100644
--- a/yarn/src/main/scala/org/apache/spark/launcher/YarnCommandBuilderUtils.scala
+++ b/yarn/src/main/scala/org/apache/spark/launcher/YarnCommandBuilderUtils.scala
@@ -19,6 +19,7 @@ package org.apache.spark.launcher
import scala.collection.JavaConverters._
import scala.collection.mutable.ListBuffer
+import scala.util.Properties
/**
* Exposes methods from the launcher library that are used by the YARN backend.
@@ -29,6 +30,14 @@ private[spark] object YarnCommandBuilderUtils {
CommandBuilderUtils.quoteForBatchScript(arg)
}
+ def findJarsDir(sparkHome: String): String = {
+ val scalaVer = Properties.versionNumberString
+ .split("\\.")
+ .take(2)
+ .mkString(".")
+ CommandBuilderUtils.findJarsDir(sparkHome, scalaVer, true)
+ }
+
/**
* Adds the perm gen configuration to the list of java options if needed and not yet added.
*
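
findJarsDir above derives the two-segment Scala binary version (for example "2.11") from the full runtime version string before asking CommandBuilderUtils where the jars directory lives. The derivation can be checked in isolation:

import scala.util.Properties

// Sketch only: the same major.minor truncation that findJarsDir applies to the
// Scala runtime version before locating the jars directory.
object ScalaBinaryVersionSketch {
  def binaryVersion: String =
    Properties.versionNumberString.split("\\.").take(2).mkString(".")

  def main(args: Array[String]): Unit = {
    println(Properties.versionNumberString)  // e.g. 2.11.8
    println(binaryVersion)                   // e.g. 2.11
  }
}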
diff --git a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala
index 9fc727904b..56dc0004d0 100644
--- a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala
+++ b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala
@@ -48,11 +48,10 @@ private[spark] class YarnClientSchedulerBackend(
val argsArrayBuf = new ArrayBuffer[String]()
argsArrayBuf += ("--arg", hostport)
- argsArrayBuf ++= getExtraClientArguments
logDebug("ClientArguments called with: " + argsArrayBuf.mkString(" "))
- val args = new ClientArguments(argsArrayBuf.toArray, conf)
- totalExpectedExecutors = args.numExecutors
+ val args = new ClientArguments(argsArrayBuf.toArray)
+ totalExpectedExecutors = YarnSparkHadoopUtil.getInitialTargetExecutorNumber(conf)
client = new Client(args, conf)
bindToYarn(client.submitApplication(), None)
@@ -73,43 +72,6 @@ private[spark] class YarnClientSchedulerBackend(
}
/**
- * Return any extra command line arguments to be passed to Client provided in the form of
- * environment variables or Spark properties.
- */
- private def getExtraClientArguments: Seq[String] = {
- val extraArgs = new ArrayBuffer[String]
- // List of (target Client argument, environment variable, Spark property)
- val optionTuples =
- List(
- ("--executor-memory", "SPARK_WORKER_MEMORY", "spark.executor.memory"),
- ("--executor-memory", "SPARK_EXECUTOR_MEMORY", "spark.executor.memory"),
- ("--executor-cores", "SPARK_WORKER_CORES", "spark.executor.cores"),
- ("--executor-cores", "SPARK_EXECUTOR_CORES", "spark.executor.cores"),
- ("--queue", "SPARK_YARN_QUEUE", "spark.yarn.queue"),
- ("--py-files", null, "spark.submit.pyFiles")
- )
- // Warn against the following deprecated environment variables: env var -> suggestion
- val deprecatedEnvVars = Map(
- "SPARK_WORKER_MEMORY" -> "SPARK_EXECUTOR_MEMORY or --executor-memory through spark-submit",
- "SPARK_WORKER_CORES" -> "SPARK_EXECUTOR_CORES or --executor-cores through spark-submit")
- optionTuples.foreach { case (optionName, envVar, sparkProp) =>
- if (sc.getConf.contains(sparkProp)) {
- extraArgs += (optionName, sc.getConf.get(sparkProp))
- } else if (envVar != null && System.getenv(envVar) != null) {
- extraArgs += (optionName, System.getenv(envVar))
- if (deprecatedEnvVars.contains(envVar)) {
- logWarning(s"NOTE: $envVar is deprecated. Use ${deprecatedEnvVars(envVar)} instead.")
- }
- }
- }
- // The app name is a special case because "spark.app.name" is required of all applications.
- // As a result, the corresponding "SPARK_YARN_APP_NAME" is already handled preemptively in
- // SparkSubmitArguments if "spark.app.name" is not explicitly set by the user. (SPARK-5222)
- sc.getConf.getOption("spark.app.name").foreach(v => extraArgs += ("--name", v))
- extraArgs
- }
-
- /**
* Report the state of the application until it is running.
* If the application has finished, failed or been killed in the process, throw an exception.
* This assumes both `client` and `appId` have already been set.
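
totalExpectedExecutors is now derived from the configuration via YarnSparkHadoopUtil.getInitialTargetExecutorNumber rather than from ClientArguments. The following is a rough approximation of that decision, under the assumption that the helper prefers the dynamic-allocation initial/min settings when dynamic allocation is enabled and spark.executor.instances (defaulting to 2) otherwise; the real helper also validates the configured bounds.

import org.apache.spark.SparkConf

// Sketch only: approximates getInitialTargetExecutorNumber. The property names are
// standard Spark settings; the default of 2 executors is an assumption.
object InitialExecutorsSketch {
  def initialTargetExecutors(conf: SparkConf, default: Int = 2): Int = {
    if (conf.getBoolean("spark.dynamicAllocation.enabled", false)) {
      val min = conf.getInt("spark.dynamicAllocation.minExecutors", 0)
      conf.getInt("spark.dynamicAllocation.initialExecutors", min)
    } else {
      conf.getInt("spark.executor.instances", default)
    }
  }

  def main(args: Array[String]): Unit = {
    val static = new SparkConf(false).set("spark.executor.instances", "4")
    println(initialTargetExecutors(static))   // 4

    val dynamic = new SparkConf(false)
      .set("spark.dynamicAllocation.enabled", "true")
      .set("spark.dynamicAllocation.minExecutors", "1")
    println(initialTargetExecutors(dynamic))  // 1, since no explicit initialExecutors is set
  }
}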
diff --git a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala
index a8781636f2..6b3c831e60 100644
--- a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala
+++ b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala
@@ -39,9 +39,12 @@ private[spark] abstract class YarnSchedulerBackend(
sc: SparkContext)
extends CoarseGrainedSchedulerBackend(scheduler, sc.env.rpcEnv) {
- if (conf.getOption("spark.scheduler.minRegisteredResourcesRatio").isEmpty) {
- minRegisteredRatio = 0.8
- }
+ override val minRegisteredRatio =
+ if (conf.getOption("spark.scheduler.minRegisteredResourcesRatio").isEmpty) {
+ 0.8
+ } else {
+ super.minRegisteredRatio
+ }
protected var totalExpectedExecutors = 0
@@ -220,17 +223,15 @@ private[spark] abstract class YarnSchedulerBackend(
val lossReasonRequest = GetExecutorLossReason(executorId)
val future = am.ask[ExecutorLossReason](lossReasonRequest, askTimeout)
future onSuccess {
- case reason: ExecutorLossReason => {
+ case reason: ExecutorLossReason =>
driverEndpoint.askWithRetry[Boolean](RemoveExecutor(executorId, reason))
- }
}
future onFailure {
- case NonFatal(e) => {
+ case NonFatal(e) =>
logWarning(s"Attempted to get executor loss reason" +
s" for executor id ${executorId} at RPC address ${executorRpcAddress}," +
s" but got no response. Marking as slave lost.", e)
driverEndpoint.askWithRetry[Boolean](RemoveExecutor(executorId, SlaveLost()))
- }
case t => throw t
}
case None =>
@@ -292,6 +293,9 @@ private[spark] abstract class YarnSchedulerBackend(
logWarning("Attempted to kill executors before the AM has registered!")
context.reply(false)
}
+
+ case RetrieveLastAllocatedExecutorId =>
+ context.reply(currentExecutorIdCounter)
}
override def onDisconnected(remoteAddress: RpcAddress): Unit = {
diff --git a/yarn/src/test/resources/log4j.properties b/yarn/src/test/resources/log4j.properties
index 6b9a799954..d13454d5ae 100644
--- a/yarn/src/test/resources/log4j.properties
+++ b/yarn/src/test/resources/log4j.properties
@@ -28,4 +28,4 @@ log4j.logger.com.sun.jersey=WARN
log4j.logger.org.apache.hadoop=WARN
log4j.logger.org.eclipse.jetty=WARN
log4j.logger.org.mortbay=WARN
-log4j.logger.org.spark-project.jetty=WARN
+log4j.logger.org.spark_project.jetty=WARN
diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/BaseYarnClusterSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/BaseYarnClusterSuite.scala
index 2f3a31cb04..9c3b18e4ec 100644
--- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/BaseYarnClusterSuite.scala
+++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/BaseYarnClusterSuite.scala
@@ -53,7 +53,7 @@ abstract class BaseYarnClusterSuite
|log4j.logger.org.apache.hadoop=WARN
|log4j.logger.org.eclipse.jetty=WARN
|log4j.logger.org.mortbay=WARN
- |log4j.logger.org.spark-project.jetty=WARN
+ |log4j.logger.org.spark_project.jetty=WARN
""".stripMargin
private var yarnCluster: MiniYARNCluster = _
diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala
index 24472e006b..74e268dc48 100644
--- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala
+++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala
@@ -17,7 +17,7 @@
package org.apache.spark.deploy.yarn
-import java.io.File
+import java.io.{File, FileOutputStream}
import java.net.URI
import java.util.Properties
@@ -118,10 +118,11 @@ class ClientSuite extends SparkFunSuite with Matchers with BeforeAndAfterAll
val sparkConf = new SparkConf()
.set(SPARK_JARS, Seq(SPARK))
.set(USER_CLASS_PATH_FIRST, true)
+ .set("spark.yarn.dist.jars", ADDED)
val env = new MutableHashMap[String, String]()
- val args = new ClientArguments(Array("--jar", USER, "--addJars", ADDED), sparkConf)
+ val args = new ClientArguments(Array("--jar", USER))
- populateClasspath(args, conf, sparkConf, env, true)
+ populateClasspath(args, conf, sparkConf, env)
val cp = env("CLASSPATH").split(":|;|<CPS>")
s"$SPARK,$USER,$ADDED".split(",").foreach({ entry =>
@@ -138,9 +139,11 @@ class ClientSuite extends SparkFunSuite with Matchers with BeforeAndAfterAll
}
test("Jar path propagation through SparkConf") {
- val sparkConf = new SparkConf().set(SPARK_JARS, Seq(SPARK))
- val client = createClient(sparkConf,
- args = Array("--jar", USER, "--addJars", ADDED))
+ val conf = new Configuration()
+ val sparkConf = new SparkConf()
+ .set(SPARK_JARS, Seq(SPARK))
+ .set("spark.yarn.dist.jars", ADDED)
+ val client = createClient(sparkConf, args = Array("--jar", USER))
val tempDir = Utils.createTempDir()
try {
@@ -178,8 +181,7 @@ class ClientSuite extends SparkFunSuite with Matchers with BeforeAndAfterAll
"/remotePath/1:/remotePath/2")
val env = new MutableHashMap[String, String]()
- populateClasspath(null, conf, sparkConf, env, false,
- extraClassPath = Some("/localPath/my1.jar"))
+ populateClasspath(null, conf, sparkConf, env, extraClassPath = Some("/localPath/my1.jar"))
val cp = classpath(env)
cp should contain ("/remotePath/spark.jar")
cp should contain ("/remotePath/my1.jar")
@@ -193,9 +195,9 @@ class ClientSuite extends SparkFunSuite with Matchers with BeforeAndAfterAll
val sparkConf = new SparkConf()
.set(APPLICATION_TAGS.key, ",tag1, dup,tag2 , ,multi word , dup")
.set(MAX_APP_ATTEMPTS, 42)
- val args = new ClientArguments(Array(
- "--name", "foo-test-app",
- "--queue", "staging-queue"), sparkConf)
+ .set("spark.app.name", "foo-test-app")
+ .set(QUEUE_NAME, "staging-queue")
+ val args = new ClientArguments(Array())
val appContext = Records.newRecord(classOf[ApplicationSubmissionContext])
val getNewApplicationResponse = Records.newRecord(classOf[GetNewApplicationResponse])
@@ -271,9 +273,10 @@ class ClientSuite extends SparkFunSuite with Matchers with BeforeAndAfterAll
test("distribute local spark jars") {
val temp = Utils.createTempDir()
- val jarsDir = new File(temp, "lib")
+ val jarsDir = new File(temp, "jars")
assert(jarsDir.mkdir())
val jar = TestUtils.createJarWithFiles(Map(), jarsDir)
+ new FileOutputStream(new File(temp, "RELEASE")).close()
val sparkConf = new SparkConfWithEnv(Map("SPARK_HOME" -> temp.getAbsolutePath()))
val client = createClient(sparkConf)
@@ -346,7 +349,7 @@ class ClientSuite extends SparkFunSuite with Matchers with BeforeAndAfterAll
sparkConf: SparkConf,
conf: Configuration = new Configuration(),
args: Array[String] = Array()): Client = {
- val clientArgs = new ClientArguments(args, sparkConf)
+ val clientArgs = new ClientArguments(args)
val client = spy(new Client(clientArgs, conf, sparkConf))
doReturn(new Path("/")).when(client).copyFileToRemote(any(classOf[Path]),
any(classOf[Path]), anyShort())
@@ -355,7 +358,7 @@ class ClientSuite extends SparkFunSuite with Matchers with BeforeAndAfterAll
private def classpath(client: Client): Array[String] = {
val env = new MutableHashMap[String, String]()
- populateClasspath(null, client.hadoopConf, client.sparkConf, env, false)
+ populateClasspath(null, client.hadoopConf, client.sparkConf, env)
classpath(env)
}
diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala
index 0587444a33..a641a6e73e 100644
--- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala
+++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala
@@ -90,12 +90,13 @@ class YarnAllocatorSuite extends SparkFunSuite with Matchers with BeforeAndAfter
def createAllocator(maxExecutors: Int = 5): YarnAllocator = {
val args = Array(
- "--executor-cores", "5",
- "--executor-memory", "2048",
"--jar", "somejar.jar",
"--class", "SomeClass")
val sparkConfClone = sparkConf.clone()
- sparkConfClone.set("spark.executor.instances", maxExecutors.toString)
+ sparkConfClone
+ .set("spark.executor.instances", maxExecutors.toString)
+ .set("spark.executor.cores", "5")
+ .set("spark.executor.memory", "2048")
new YarnAllocator(
"not used",
mock(classOf[RpcEndpointRef]),
@@ -103,7 +104,6 @@ class YarnAllocatorSuite extends SparkFunSuite with Matchers with BeforeAndAfter
sparkConfClone,
rmClient,
appAttemptId,
- new ApplicationMasterArguments(args),
new SecurityManager(sparkConf))
}
diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala
index 26520529ec..b2b4d84f53 100644
--- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala
+++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala
@@ -85,6 +85,35 @@ class YarnClusterSuite extends BaseYarnClusterSuite {
testBasicYarnApp(false)
}
+ test("run Spark in yarn-client mode with different configurations") {
+ testBasicYarnApp(true,
+ Map(
+ "spark.driver.memory" -> "512m",
+ "spark.executor.cores" -> "1",
+ "spark.executor.memory" -> "512m",
+ "spark.executor.instances" -> "2"
+ ))
+ }
+
+ test("run Spark in yarn-cluster mode with different configurations") {
+ testBasicYarnApp(true,
+ Map(
+ "spark.driver.memory" -> "512m",
+ "spark.driver.cores" -> "1",
+ "spark.executor.cores" -> "1",
+ "spark.executor.memory" -> "512m",
+ "spark.executor.instances" -> "2"
+ ))
+ }
+
+ test("run Spark in yarn-client mode with additional jar") {
+ testWithAddJar(true)
+ }
+
+ test("run Spark in yarn-cluster mode with additional jar") {
+ testWithAddJar(false)
+ }
+
test("run Spark in yarn-cluster mode unsuccessfully") {
// Don't provide arguments so the driver will fail.
val finalState = runSpark(false, mainClassName(YarnClusterDriver.getClass))
@@ -139,13 +168,26 @@ class YarnClusterSuite extends BaseYarnClusterSuite {
}
}
- private def testBasicYarnApp(clientMode: Boolean): Unit = {
+ private def testBasicYarnApp(clientMode: Boolean, conf: Map[String, String] = Map()): Unit = {
val result = File.createTempFile("result", null, tempDir)
val finalState = runSpark(clientMode, mainClassName(YarnClusterDriver.getClass),
- appArgs = Seq(result.getAbsolutePath()))
+ appArgs = Seq(result.getAbsolutePath()),
+ extraConf = conf)
checkResult(finalState, result)
}
+ private def testWithAddJar(clientMode: Boolean): Unit = {
+ val originalJar = TestUtils.createJarWithFiles(Map("test.resource" -> "ORIGINAL"), tempDir)
+ val driverResult = File.createTempFile("driver", null, tempDir)
+ val executorResult = File.createTempFile("executor", null, tempDir)
+ val finalState = runSpark(clientMode, mainClassName(YarnClasspathTest.getClass),
+ appArgs = Seq(driverResult.getAbsolutePath(), executorResult.getAbsolutePath()),
+ extraClassPath = Seq(originalJar.getPath()),
+ extraJars = Seq("local:" + originalJar.getPath()))
+ checkResult(finalState, driverResult, "ORIGINAL")
+ checkResult(finalState, executorResult, "ORIGINAL")
+ }
+
private def testPySpark(clientMode: Boolean): Unit = {
val primaryPyFile = new File(tempDir, "test.py")
Files.write(TEST_PYFILE, primaryPyFile, StandardCharsets.UTF_8)
diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala
index de14e36f4e..fe09808ae5 100644
--- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala
+++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala
@@ -101,22 +101,18 @@ class YarnSparkHadoopUtilSuite extends SparkFunSuite with Matchers with Logging
val modifyAcls = acls.get(ApplicationAccessType.MODIFY_APP)
viewAcls match {
- case Some(vacls) => {
+ case Some(vacls) =>
val aclSet = vacls.split(',').map(_.trim).toSet
assert(aclSet.contains(System.getProperty("user.name", "invalid")))
- }
- case None => {
+ case None =>
fail()
- }
}
modifyAcls match {
- case Some(macls) => {
+ case Some(macls) =>
val aclSet = macls.split(',').map(_.trim).toSet
assert(aclSet.contains(System.getProperty("user.name", "invalid")))
- }
- case None => {
+ case None =>
fail()
- }
}
}
@@ -135,26 +131,22 @@ class YarnSparkHadoopUtilSuite extends SparkFunSuite with Matchers with Logging
val modifyAcls = acls.get(ApplicationAccessType.MODIFY_APP)
viewAcls match {
- case Some(vacls) => {
+ case Some(vacls) =>
val aclSet = vacls.split(',').map(_.trim).toSet
assert(aclSet.contains("user1"))
assert(aclSet.contains("user2"))
assert(aclSet.contains(System.getProperty("user.name", "invalid")))
- }
- case None => {
+ case None =>
fail()
- }
}
modifyAcls match {
- case Some(macls) => {
+ case Some(macls) =>
val aclSet = macls.split(',').map(_.trim).toSet
assert(aclSet.contains("user3"))
assert(aclSet.contains("user4"))
assert(aclSet.contains(System.getProperty("user.name", "invalid")))
- }
- case None => {
+ case None =>
fail()
- }
}
}