aboutsummaryrefslogtreecommitdiff
path: root/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala
blob: c1ebf8b09e08d49ca3a45f59d87dc2f9829e09f1 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.analysis
import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.planning.ExtractFiltersAndInnerJoins
import org.apache.spark.sql.catalyst.plans.{Inner, PlanTest}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.RuleExecutor


/**
 * Tests for join-related optimizations: flattening of inner joins and their conditions by
 * `ExtractFiltersAndInnerJoins`, join reordering by `ReorderJoin`, and column pruning beneath
 * a `BroadcastHint` together with the statistics the hinted side reports.
 */
class JoinOptimizationSuite extends PlanTest {

  /**
   * Optimizer shared by the tests in this suite. It strips subquery aliases once, then runs
   * filter combination/push-down, boolean simplification, join reordering, column pruning and
   * project collapsing as a fixed-point batch (capped at 100 iterations).
   */
  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches =
      Batch("Subqueries", Once,
        EliminateSubqueryAliases) ::
      Batch("Filter Pushdown", FixedPoint(100),
        CombineFilters,
        PushDownPredicate,
        BooleanSimplification,
        ReorderJoin,
        PushPredicateThroughJoin,
        ColumnPruning,
        CollapseProject) :: Nil
  }

  /** Relation with three integer columns: `a`, `b` and `c`. */
  val testRelation = LocalRelation('a.int, 'b.int, 'c.int)
  /** Relation with a single integer column: `d`. */
  val testRelation1 = LocalRelation('d.int)

  test("extract filters and joins") {
    val x = testRelation.subquery('x)
    val y = testRelation1.subquery('y)
    val z = testRelation.subquery('z)

    // Asserts that the extractor yields exactly `expected`: the flattened join inputs
    // paired with every join/filter condition collected along the way.
    def testExtract(
        plan: LogicalPlan,
        expected: Option[(Seq[LogicalPlan], Seq[Expression])]): Unit = {
      assert(ExtractFiltersAndInnerJoins.unapply(plan) === expected)
    }

    // Plans that contain no join are not extracted at all.
    testExtract(x, None)
    testExtract(x.where("x.b".attr === 1), None)

    // A single join; its own condition and a filter directly above it are both collected.
    testExtract(x.join(y), Some((Seq(x, y), Seq())))
    testExtract(x.join(y, condition = Some("x.b".attr === "y.d".attr)),
      Some((Seq(x, y), Seq("x.b".attr === "y.d".attr))))
    testExtract(x.join(y).where("x.b".attr === "y.d".attr),
      Some((Seq(x, y), Seq("x.b".attr === "y.d".attr))))

    // Left-nested joins are flattened, including filters sitting between join levels.
    testExtract(x.join(y).join(z), Some((Seq(x, y, z), Seq())))
    testExtract(x.join(y).where("x.b".attr === "y.d".attr).join(z),
      Some((Seq(x, y, z), Seq("x.b".attr === "y.d".attr))))

    // A join on the right-hand side is kept as a single, unflattened input.
    testExtract(x.join(y).join(x.join(z)), Some((Seq(x, y, x.join(z)), Seq())))
    testExtract(x.join(y).join(x.join(z)).where("x.b".attr === "y.d".attr),
      Some((Seq(x, y, x.join(z)), Seq("x.b".attr === "y.d".attr))))
  }

  test("reorder inner joins") {
    val x = testRelation.subquery('x)
    val y = testRelation1.subquery('y)
    val z = testRelation.subquery('z)

    // No predicate relates `x` and `y` directly; each of them is related only to `z`.
    val originalQuery = {
      x.join(y).join(z)
        .where(("x.b".attr === "z.b".attr) && ("y.d".attr === "z.a".attr))
    }

    val optimized = Optimize.execute(originalQuery.analyze)

    // Expected: `x` is joined with `z` first (via x.b = z.b), then `y` is attached
    // (via y.d = z.a), and each predicate becomes the condition of its join.
    val correctAnswer =
      x.join(z, condition = Some("x.b".attr === "z.b".attr))
        .join(y, condition = Some("y.d".attr === "z.a".attr))
        .analyze

    // The optimizer's first batch removes subquery aliases, so strip them from the
    // expected plan as well before comparing.
    comparePlans(optimized, analysis.EliminateSubqueryAliases(correctAnswer))
  }

  test("broadcasthint sets relation statistics to smallest value") {
    val input = LocalRelation('key.int, 'value.string)

    val query =
      Project(Seq($"x.key", $"y.key"),
        Join(
          SubqueryAlias("x", input),
          BroadcastHint(SubqueryAlias("y", input)), Inner, None)).analyze

    val optimized = Optimize.execute(query)

    // Column pruning pushes the projections beneath the join, and on the right side
    // beneath the broadcast hint, which itself is preserved.
    val expected =
      Join(
        Project(Seq($"x.key"), SubqueryAlias("x", input)),
        BroadcastHint(Project(Seq($"y.key"), SubqueryAlias("y", input))),
        Inner, None).analyze

    comparePlans(optimized, expected)

    // Exactly one join should have a right child reporting sizeInBytes == 1,
    // i.e. the hinted side reports the smallest size.
    val broadcastChildren = optimized.collect {
      case Join(_, r, _, _) if r.statistics.sizeInBytes == 1 => r
    }
    assert(broadcastChildren.size === 1)
  }
}