@@ -104,40 +104,59 @@ def build_subgraph_buffer(
         args: The args that are passed into the subgraph. Contains both fixed and lifted inputs.
         subgraph: The Subgraph ir for which to produce the output node
     """
-    from ..subgraph_lowering import PointwiseSubgraphLowering
-
-    pw_subgraph = PointwiseSubgraphLowering(
-        subgraph.graph_module, root_graph_lowering=V.graph
-    )
-    with V.set_graph_handler(pw_subgraph):  # type: ignore[arg-type]
-        pw_subgraph.run(*args)
-
-    def convert_output_node_to_buffer(output):
-        if output is None:
-            return None
-        output_buffer = output
-        assert isinstance(output_buffer, TensorBox), (
-            "The output node for flex attention's subgraph must be a TensorBox, but got: ",
-            type(output_buffer),
-        )
-        assert isinstance(output_buffer.data, StorageBox), (
-            "The output node for the flex attention subgraph must be a StorageBox, but got: ",
-            type(output_buffer),
-        )
-        subgraph_buffer = ComputedBuffer(
-            name=None,
-            layout=FlexibleLayout(
-                device=output_buffer.data.get_device(),
-                dtype=output_buffer.data.get_dtype(),
-                size=output_buffer.data.get_size(),
-            ),
-            data=output_buffer.data.data,  # type: ignore[arg-type]
-        )
-        return subgraph_buffer
-
-    # node.args[0] is either a single element or a list of elements
-    # representing all outputs of the function.
-    return tree_map(convert_output_node_to_buffer, pw_subgraph.graph_outputs)
+    cnt = 0
+    env = {}
+    for node in subgraph.graph_module.graph.nodes:
+        # There are two classes of placeholder inputs that we need
+        # to handle differently. For the first n_scalar_inps inputs
+        # we expect that these placeholders were generated by the make_fx call
+        # in the flex attention HOP. So we need to create a new placeholder
+        # TensorBox for each of these inputs. For the rest of the inputs we
+        # expect that these are lifted inputs that fill up the '*other_buffers'
+        # tuple and already have corresponding TensorBoxes passed in as args.
+        with V.graph.set_current_node(node):
+            if node.op == "placeholder":
+                env[node] = args[cnt]
+                cnt += 1
+            elif node.op == "call_function":
+                # For call_function we use the default lowerings and pass in the
+                # already created TensorBoxes as args
+
+                args, kwargs = tree_map(
+                    lambda x: env[x] if x in env else x, (node.args, node.kwargs)
+                )
+                env[node] = lowerings[node.target](*args, **kwargs)
+            elif node.op == "output":
+
+                def convert_output_node_to_buffer(output):
+                    if output is None:
+                        return None
+                    output_node = output
+                    output_buffer = env[output_node]
+                    assert isinstance(output_buffer, TensorBox), (
+                        "The output node for flex attention's subgraph must be a TensorBox, but got: ",
+                        type(output_buffer),
+                    )
+                    assert isinstance(output_buffer.data, StorageBox), (
+                        "The output node for the flex attention subgraph must be a StorageBox, but got: ",
+                        type(output_buffer),
+                    )
+                    subgraph_buffer = ComputedBuffer(
+                        name=None,
+                        layout=FlexibleLayout(
+                            device=output_buffer.data.get_device(),
+                            dtype=output_buffer.data.get_dtype(),
+                            size=output_buffer.data.get_size(),
+                        ),
+                        data=output_buffer.data.data,  # type: ignore[arg-type]
+                    )
+                    return subgraph_buffer
+
+                # node.args[0] is either a single element or a list of elements
+                # representing all outputs of the function.
+                return tree_map(convert_output_node_to_buffer, node.args[0])
+
+    raise ValueError("FlexAttention was passed a subgraph with no output node!")


 # Inner Triton functions shared by flex_attention & split-k decoding kernels.
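
Note on the hunk above: the replacement body of build_subgraph_buffer interprets the score_mod FX graph node by node, binding placeholders to the incoming args, dispatching call_function nodes through Inductor's `lowerings` table, and converting the output node into a ComputedBuffer. The sketch below is an assumed, simplified illustration of that same placeholder / call_function / output walk over a plain torch.fx graph; the "lowering" here just calls node.target eagerly, and score_mod plus its inputs are made-up examples rather than code from this PR.

# Illustrative sketch only (assumed example, not the Inductor code above).
import torch
from torch.fx import symbolic_trace
from torch.utils._pytree import tree_map


def score_mod(score, b, h, q_idx, kv_idx):
    # Made-up score_mod: a relative-position bias.
    return score + (q_idx - kv_idx)


def run_subgraph(gm, *args):
    env, cnt = {}, 0
    for node in gm.graph.nodes:
        if node.op == "placeholder":
            # Bind the n-th placeholder to the n-th caller-supplied input.
            env[node] = args[cnt]
            cnt += 1
        elif node.op == "call_function":
            # Swap FX node references for values already computed for them.
            a, kw = tree_map(
                lambda x: env[x] if x in env else x, (node.args, node.kwargs)
            )
            env[node] = node.target(*a, **kw)
        elif node.op == "output":
            # node.args[0] is a single element or a pytree of all outputs.
            return tree_map(lambda x: env[x] if x in env else x, node.args[0])
    raise ValueError("graph has no output node")


gm = symbolic_trace(score_mod)
out = run_subgraph(gm, torch.tensor(1.0), 0, 0, torch.tensor(5), torch.tensor(3))
print(out)  # tensor(3.)

The Inductor version differs mainly in what it stores in env: TensorBoxes built by `lowerings[node.target]` rather than eager results, which is why the output node is wrapped into a ComputedBuffer with a FlexibleLayout instead of being returned directly.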
@@ -503,7 +522,7 @@ def forward_block_mn(
         ) | indent_except_first(2) }}

         if CHECK_BLOCK_BOUNDARY:
-            mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False)
+            mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, float("-inf"))
         # apply mask for partially unmasked blocks
         post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))

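This hunk restores the older CHECK_BLOCK_BOUNDARY fill value for out-of-bounds columns before mask_mod_output is applied to post_mod_scores. The snippet below is a plain-PyTorch illustration (assumed shapes and values, not Triton) of the second step only: entries whose mask is False are driven to -inf, so softmax assigns them zero attention weight.

# Assumed toy example: masking scores with -inf removes their softmax weight.
import torch

scores = torch.tensor([1.0, 2.0, 3.0, 4.0])     # stand-in for post_mod_scores
mask = torch.tensor([True, True, False, True])  # stand-in for mask_mod_output
masked = torch.where(mask, scores, torch.tensor(float("-inf")))
print(torch.softmax(masked, dim=-1))            # the masked slot gets exactly 0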
@@ -1739,8 +1758,6 @@ def flex_attention_backward(*args, **kwargs):
     joint_placeholder_inps = fwd_placeholder_inps + [
         create_placeholder("grad_score_mod", dtype, device)
     ]
-    # Sometimes we have weird unused nodes here
-    joint_graph.graph_module.graph.eliminate_dead_code()
     joint_subgraph_buffer, *_ = build_subgraph_buffer(
         joint_placeholder_inps + list(score_mod_other_buffers), joint_graph
     )
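
The backward-path hunk drops the dead-code-elimination pass that was previously run on the joint graph before building its subgraph buffer. For reference, the sketch below shows what torch.fx's graph.eliminate_dead_code() does on a small traced graph; the function f and its unused intermediate are made up for illustration and unrelated to the PR's graphs.

# Assumed example of torch.fx dead-code elimination.
import torch
from torch.fx import symbolic_trace


def f(x):
    unused = x * 2  # traced into the graph, but its result is never consumed
    return x + 1


gm = symbolic_trace(f)
print([n.op for n in gm.graph.nodes])  # ['placeholder', 'call_function', 'call_function', 'output']
gm.graph.eliminate_dead_code()         # removes the unused multiply node
gm.recompile()
print([n.op for n in gm.graph.nodes])  # ['placeholder', 'call_function', 'output']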