| 
1 | 1 | package dotty.tools.dotc  | 
2 | 2 | package transform.localopt  | 
3 | 3 | 
 
  | 
 | 4 | +import dotty.tools.dotc.ast.desugar.TrailingForMap  | 
4 | 5 | import dotty.tools.dotc.ast.tpd.*  | 
5 |  | -import dotty.tools.dotc.core.Decorators.*  | 
6 | 6 | import dotty.tools.dotc.core.Contexts.*  | 
 | 7 | +import dotty.tools.dotc.core.Decorators.*  | 
 | 8 | +import dotty.tools.dotc.core.Flags.*  | 
7 | 9 | import dotty.tools.dotc.core.StdNames.*  | 
8 | 10 | import dotty.tools.dotc.core.Symbols.*  | 
9 | 11 | import dotty.tools.dotc.core.Types.*  | 
10 | 12 | import dotty.tools.dotc.transform.MegaPhase.MiniPhase  | 
11 |  | -import dotty.tools.dotc.ast.desugar  | 
12 | 13 | 
 
  | 
13 | 14 | /** Drop unused trailing map calls in for comprehensions.  | 
14 |  | -  * We can drop the map call if:  | 
15 |  | -  * - it won't change the type of the expression, and  | 
16 |  | -  * - the function is an identity function or a const function to unit.  | 
17 |  | -  *  | 
18 |  | -  * The latter condition is checked in [[Desugar.scala#makeFor]]  | 
19 |  | -  */  | 
 | 15 | + *  | 
 | 16 | + *  We can drop the map call if:  | 
 | 17 | + *  - it won't change the type of the expression, and  | 
 | 18 | + *  - the function is an identity function or a const function to unit.  | 
 | 19 | + *  | 
 | 20 | + *  The latter condition is checked in [[Desugar.scala#makeFor]]  | 
 | 21 | + */  | 
20 | 22 | class DropForMap extends MiniPhase:  | 
21 |  | -  import DropForMap.*  | 
22 | 23 | 
 
  | 
23 | 24 |   override def phaseName: String = DropForMap.name  | 
24 | 25 | 
 
  | 
25 | 26 |   override def description: String = DropForMap.description  | 
26 | 27 | 
 
  | 
27 |  | -  override def transformApply(tree: Apply)(using Context): Tree =  | 
28 |  | -    if !tree.hasAttachment(desugar.TrailingForMap) then tree  | 
29 |  | -    else tree match  | 
30 |  | -      case aply @ Apply(MapCall(f), List(Lambda(List(param), body)))  | 
31 |  | -      if f.tpe =:= aply.tpe => // make sure that the type of the expression won't change  | 
 | 28 | +  import DropForMap.{Converted, Unmapped}  | 
 | 29 | + | 
 | 30 | +  /** r.map(x => x)(using y) --> r  | 
 | 31 | +   *       ^ TrailingForMap  | 
 | 32 | +   */  | 
 | 33 | +  override def transformApply(tree: Apply)(using Context): Tree = tree match  | 
 | 34 | +    case Unmapped(f0, sym, args) =>  | 
 | 35 | +      val f =  | 
 | 36 | +        if sym.is(Extension) then args.head  | 
 | 37 | +        else f0  | 
 | 38 | +      if f.tpe.widen =:= tree.tpe then // make sure that the type of the expression won't change  | 
32 | 39 |         f // drop the map call  | 
33 |  | -      case _ =>  | 
34 |  | -        tree.removeAttachment(desugar.TrailingForMap)  | 
35 |  | -        tree  | 
 | 40 | +      else  | 
 | 41 | +        f match  | 
 | 42 | +        case Converted(r) if r.tpe =:= tree.tpe => r // drop the map call and the conversion  | 
 | 43 | +        case _ => tree  | 
 | 44 | +    case tree => tree  | 
36 | 45 | 
 
  | 
37 |  | -  private object Lambda:  | 
38 |  | -    def unapply(tree: Tree)(using Context): Option[(List[ValDef], Tree)] =  | 
39 |  | -      tree match  | 
40 |  | -        case Block(List(defdef: DefDef), Closure(Nil, ref, _))  | 
41 |  | -        if ref.symbol == defdef.symbol && !defdef.paramss.exists(_.forall(_.isType)) =>  | 
42 |  | -          Some((defdef.termParamss.flatten, defdef.rhs))  | 
 | 46 | +  /** If the map was inlined, fetch the binding for the receiver,  | 
 | 47 | +   *  then find the tree in the expansion that refers to the binding.  | 
 | 48 | +   *  That is the expansion of the result Inlined node.  | 
 | 49 | +   */  | 
 | 50 | +  override def transformInlined(tree: Inlined)(using Context): Tree = tree match  | 
 | 51 | +    case Inlined(call, bindings, expansion) if call.hasAttachment(TrailingForMap) =>  | 
 | 52 | +      val expansion1 =  | 
 | 53 | +        call match  | 
 | 54 | +        case Unmapped(f0, sym, args) =>  | 
 | 55 | +          val f =  | 
 | 56 | +            if sym.is(Extension) then args.head  | 
 | 57 | +            else f0  | 
 | 58 | +          if f.tpe.widen =:= expansion.tpe then  | 
 | 59 | +            bindings.collectFirst:  | 
 | 60 | +              case vd: ValDef if f.sameTree(vd.rhs) =>  | 
 | 61 | +                expansion.find:  | 
 | 62 | +                  case Inlined(Thicket(Nil), Nil, Ident(ident)) => ident == vd.name  | 
 | 63 | +                  case _ => false  | 
 | 64 | +                .getOrElse(expansion)  | 
 | 65 | +            .getOrElse(expansion)  | 
 | 66 | +          else  | 
 | 67 | +            f match  | 
 | 68 | +            case Converted(r) if r.tpe =:= expansion.tpe => r // drop the map call and the conversion  | 
 | 69 | +            case _ => expansion  | 
 | 70 | +        case _ => expansion  | 
 | 71 | +      if expansion1 ne expansion then  | 
 | 72 | +        cpy.Inlined(tree)(call, bindings, expansion1)  | 
 | 73 | +      else tree  | 
 | 74 | +    case tree => tree  | 
 | 75 | + | 
 | 76 | +object DropForMap:  | 
 | 77 | +  val name: String = "dropForMap"  | 
 | 78 | +  val description: String = "Drop unused trailing map calls in for comprehensions"  | 
 | 79 | + | 
 | 80 | +  // Extracts a fun from a possibly nested Apply with lambda and arbitrary implicit args.  | 
 | 81 | +  // Specifically, an application `r.map(x => x)` is destructured into (r, map, args).  | 
 | 82 | +  // If the receiver r was adapted, it is unwrapped.  | 
 | 83 | +  // If `map` is an extension method, the nominal receiver is `args.head`.  | 
 | 84 | +  private object Unmapped:  | 
 | 85 | +    private def loop(tree: Tree, args: List[Tree])(using Context): Option[(Tree, Symbol, List[Tree])] = tree match  | 
 | 86 | +      case Apply(fun, args @ Lambda(_ :: Nil, _) :: Nil) =>  | 
 | 87 | +        tree.removeAttachment(TrailingForMap) match  | 
 | 88 | +        case Some(_) =>  | 
 | 89 | +          fun match  | 
 | 90 | +          case MapCall(f, sym, args) => Some((f, sym, args))  | 
 | 91 | +          case _ => None  | 
 | 92 | +        case _ => None  | 
 | 93 | +      case Apply(fun, _) =>  | 
 | 94 | +        fun.tpe match  | 
 | 95 | +        case mt: MethodType if mt.isImplicitMethod => loop(fun, args)  | 
43 | 96 |         case _ => None  | 
 | 97 | +      case TypeApply(fun, _) => loop(fun, args)  | 
 | 98 | +      case _ => None  | 
 | 99 | +    end loop  | 
 | 100 | +    def unapply(tree: Apply)(using Context): Option[(Tree, Symbol, List[Tree])] =  | 
 | 101 | +      tree.tpe match  | 
 | 102 | +      case _: MethodOrPoly => None  | 
 | 103 | +      case _ => loop(tree, args = Nil)  | 
 | 104 | + | 
 | 105 | +  private object Lambda:  | 
 | 106 | +    def unapply(tree: Tree)(using Context): Option[(List[ValDef], Tree)] = tree match  | 
 | 107 | +      case Block(List(defdef: DefDef), Closure(Nil, ref, _))  | 
 | 108 | +      if ref.symbol == defdef.symbol && !defdef.paramss.exists(_.forall(_.isType)) =>  | 
 | 109 | +        Some((defdef.termParamss.flatten, defdef.rhs))  | 
 | 110 | +      case _ => None  | 
44 | 111 | 
 
  | 
45 | 112 |   private object MapCall:  | 
 | 113 | +    def unapply(tree: Tree)(using Context): Option[(Tree, Symbol, List[Tree])] =  | 
 | 114 | +      def loop(tree: Tree, args: List[Tree]): Option[(Tree, Symbol, List[Tree])] =  | 
 | 115 | +        tree match  | 
 | 116 | +        case Ident(nme.map) if tree.symbol.is(Extension) => Some((EmptyTree, tree.symbol, args))  | 
 | 117 | +        case Select(f, nme.map) => Some((f, tree.symbol, args))  | 
 | 118 | +        case Apply(fn, args) => loop(fn, args)  | 
 | 119 | +        case TypeApply(fn, _) => loop(fn, args)  | 
 | 120 | +        case _ => None  | 
 | 121 | +      loop(tree, Nil)  | 
 | 122 | + | 
 | 123 | +  private object Converted:  | 
46 | 124 |     def unapply(tree: Tree)(using Context): Option[Tree] = tree match  | 
47 |  | -      case Select(f, nme.map) => Some(f)  | 
48 |  | -      case Apply(fn, _) => unapply(fn)  | 
 | 125 | +      case Apply(fn @ Apply(_, _), _) => unapply(fn)  | 
 | 126 | +      case Apply(fn, r :: Nil)  | 
 | 127 | +      if fn.symbol.is(Implicit) || fn.symbol.name == nme.apply && fn.symbol.owner.derivesFrom(defn.ConversionClass)  | 
 | 128 | +      => Some(r)  | 
49 | 129 |       case TypeApply(fn, _) => unapply(fn)  | 
50 | 130 |       case _ => None  | 
51 |  | - | 
52 |  | -object DropForMap:  | 
53 |  | -  val name: String = "dropForMap"  | 
54 |  | -  val description: String = "Drop unused trailing map calls in for comprehensions"  | 
 | 
0 commit comments