Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 15 additions & 1 deletion llvm/lib/Target/AMDGPU/AMDGPUCombine.td
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,20 @@ def sign_extension_in_reg : GICombineRule<
[{ return matchCombineSignExtendInReg(*${sign_inreg}, ${matchinfo}); }]),
(apply [{ applyCombineSignExtendInReg(*${sign_inreg}, ${matchinfo}); }])>;

// Match data passed from the C++ match step to the C++ apply step of
// cond_sub_to_fma; backed by the ConditionalSubToFMAMatchInfo struct.
def cond_sub_to_fma_matchdata : GIDefMatchData<"ConditionalSubToFMAMatchInfo">;

// Optimize conditional subtraction patterns to FMA:
// result = a - (cond ? c : 0.0) -> fma(select(cond, -1.0, 0.0), c, a)
// result = a + (cond ? -c : 0.0) -> fma(select(cond, -1.0, 0.0), c, a)
// result = a + (-(cond ? c : 0.0)) -> fma(select(cond, -1.0, 0.0), c, a)
//
// Only enabled for f64. The C++ matcher gates this on
// STI.shouldUseConditionalSubToFMAF64() and on loop info being available.
// NOTE(review): presumably that subtarget query implies hasFmacF64Inst() --
// confirm against the subtarget definition.
//
// The rule roots on any G_FSUB or G_FADD; all real pattern checking is done
// in matchConditionalSubToFMA.
def cond_sub_to_fma : GICombineRule<
(defs root:$fsub_or_fadd, cond_sub_to_fma_matchdata:$matchinfo),
(match (wip_match_opcode G_FSUB, G_FADD):$fsub_or_fadd,
[{ return matchConditionalSubToFMA(*${fsub_or_fadd}, ${matchinfo}); }]),
(apply [{ applyConditionalSubToFMA(*${fsub_or_fadd}, ${matchinfo}); }])>;

// Do the following combines :
// fmul x, select(y, A, B) -> fldexp (x, select i32 (y, a, b))
// fmul x, select(y, -A, -B) -> fldexp ((fneg x), select i32 (y, a, b))
Expand Down Expand Up @@ -228,7 +242,7 @@ def AMDGPUPostLegalizerCombiner: GICombiner<
[all_combines, gfx6gfx7_combines, gfx8_combines, combine_fmul_with_select_to_fldexp,
uchar_to_float, cvt_f32_ubyteN, remove_fcanonicalize, foldable_fneg,
rcp_sqrt_to_rsq, fdiv_by_sqrt_to_rsq_f16, sign_extension_in_reg, smulu64,
binop_s64_with_s32_mask_combines, combine_or_s64_s32]> {
cond_sub_to_fma, binop_s64_with_s32_mask_combines, combine_or_s64_s32]> {
let CombineAllMethodName = "tryCombineAllImpl";
}

Expand Down
134 changes: 130 additions & 4 deletions llvm/lib/Target/AMDGPU/AMDGPUPostLegalizerCombiner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include "llvm/CodeGen/GlobalISel/GenericMachineInstrs.h"
#include "llvm/CodeGen/GlobalISel/MIPatternMatch.h"
#include "llvm/CodeGen/MachineDominators.h"
#include "llvm/CodeGen/MachineLoopInfo.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/IR/IntrinsicsAMDGPU.h"
#include "llvm/Target/TargetMachine.h"
Expand All @@ -47,6 +48,7 @@ class AMDGPUPostLegalizerCombinerImpl : public Combiner {
const AMDGPUPostLegalizerCombinerImplRuleConfig &RuleConfig;
const GCNSubtarget &STI;
const SIInstrInfo &TII;
const MachineLoopInfo *MLI;
// TODO: Make CombinerHelper methods const.
mutable AMDGPUCombinerHelper Helper;

Expand All @@ -56,7 +58,7 @@ class AMDGPUPostLegalizerCombinerImpl : public Combiner {
GISelValueTracking &VT, GISelCSEInfo *CSEInfo,
const AMDGPUPostLegalizerCombinerImplRuleConfig &RuleConfig,
const GCNSubtarget &STI, MachineDominatorTree *MDT,
const LegalizerInfo *LI);
const MachineLoopInfo *MLI, const LegalizerInfo *LI);

static const char *getName() { return "AMDGPUPostLegalizerCombinerImpl"; }

Expand Down Expand Up @@ -113,6 +115,18 @@ class AMDGPUPostLegalizerCombinerImpl : public Combiner {
// bits are zero extended.
bool matchCombine_s_mul_u64(MachineInstr &MI, unsigned &NewOpcode) const;

// Operands captured by matchConditionalSubToFMA and consumed by
// applyConditionalSubToFMA, which builds fma(select(Cond, -1.0, 0.0), C, A).
// The matcher fills this with aggregate initialization as {Cond, C, A}, so
// the field order is significant.
struct ConditionalSubToFMAMatchInfo {
// Condition operand of the matched G_SELECT.
Register Cond;
// Value that is conditionally subtracted (the select's true value, or the
// source of the G_FNEG when the true value is a negation).
Register C;
// Operand 1 of the matched G_FSUB/G_FADD; becomes the FMA addend.
Register A;
};

// Returns true and fills MatchInfo when MI (a G_FSUB or G_FADD) has the form
// of a conditional subtraction whose select false value is 0.0.
bool matchConditionalSubToFMA(MachineInstr &MI,
ConditionalSubToFMAMatchInfo &MatchInfo) const;
// Replaces MI with fma(select(Cond, -1.0, 0.0), C, A) built from MatchInfo
// and erases MI.
void applyConditionalSubToFMA(MachineInstr &MI,
const ConditionalSubToFMAMatchInfo &MatchInfo) const;

private:
#define GET_GICOMBINER_CLASS_MEMBERS
#define AMDGPUSubtarget GCNSubtarget
Expand All @@ -131,9 +145,10 @@ AMDGPUPostLegalizerCombinerImpl::AMDGPUPostLegalizerCombinerImpl(
MachineFunction &MF, CombinerInfo &CInfo, const TargetPassConfig *TPC,
GISelValueTracking &VT, GISelCSEInfo *CSEInfo,
const AMDGPUPostLegalizerCombinerImplRuleConfig &RuleConfig,
const GCNSubtarget &STI, MachineDominatorTree *MDT, const LegalizerInfo *LI)
const GCNSubtarget &STI, MachineDominatorTree *MDT,
const MachineLoopInfo *MLI, const LegalizerInfo *LI)
: Combiner(MF, CInfo, TPC, &VT, CSEInfo), RuleConfig(RuleConfig), STI(STI),
TII(*STI.getInstrInfo()),
TII(*STI.getInstrInfo()), MLI(MLI),
Helper(Observer, B, /*IsPreLegalize*/ false, &VT, MDT, LI, STI),
#define GET_GICOMBINER_CONSTRUCTOR_INITS
#include "AMDGPUGenPostLegalizeGICombiner.inc"
Expand Down Expand Up @@ -435,6 +450,112 @@ bool AMDGPUPostLegalizerCombinerImpl::matchCombine_s_mul_u64(
return false;
}

// Match conditional subtraction patterns for FMA optimization.
//
// This function identifies patterns like:
// result = a - (cond ? c : 0.0)
// result = a + (cond ? -c : 0.0)
// result = a + (-(cond ? c : 0.0))
//
// These can be converted to an efficient FMA:
// result = fma((cond ? -1.0 : 0.0), c, a)
//
// Tries to match MI (a G_FSUB or G_FADD) against the conditional-subtraction
// patterns and, on success, records the operands needed to build the FMA.
//
// \param MI        the candidate G_FSUB/G_FADD instruction.
// \param MatchInfo out-parameter; on success holds {Cond, C, A}, where Cond is
//                  the select condition, C the value being conditionally
//                  subtracted and A operand 1 of MI. Only written on success.
// \returns true if one of the patterns matched, false otherwise.
//
// NOTE(review): only operand 2 of MI is inspected, so for G_FADD a select in
// operand 1 is not matched even though fadd is commutative -- confirm this is
// intentional.
// NOTE(review): the rewrite multiplies C by 0.0 when Cond is false, which
// presumably differs from the original for NaN/infinite C and for signed
// zeros -- verify whether fast-math flags on MI need to be checked here.
bool AMDGPUPostLegalizerCombinerImpl::matchConditionalSubToFMA(
MachineInstr &MI, ConditionalSubToFMAMatchInfo &MatchInfo) const {
// Only optimize f64 with FMAC support, and check VOPD constraints.
// MLI is also required because the loop-context check below dereferences it.
if (!MLI || !STI.shouldUseConditionalSubToFMAF64())
return false;

// Restrict the combine to 64-bit scalar results.
Register DstReg = MI.getOperand(0).getReg();
LLT Ty = MRI.getType(DstReg);
if (Ty != LLT::scalar(64))
return false;

// A is the untouched addend/minuend; RHS is where the select (possibly
// behind a G_FNEG) is expected.
Register A = MI.getOperand(1).getReg();
Register RHS = MI.getOperand(2).getReg();
MachineInstr *RHSMI = MRI.getVRegDef(RHS);
if (!RHSMI)
return false;

// Returns true if SelMI is a valid select with false value = 0.0.
// On success Cond receives the select's condition (operand 1) and TrueVal its
// true value (operand 2); neither is written on failure.
auto matchSelectWithZero = [this, &MI](MachineInstr *SelMI, Register &Cond,
Register &TrueVal) -> bool {
if (!SelMI || SelMI->getOpcode() != TargetOpcode::G_SELECT)
return false;

// Check if FalseVal is exactly 0.0.
Register FalseVal = SelMI->getOperand(3).getReg();
auto FalseConst = getFConstantVRegValWithLookThrough(FalseVal, MRI);
if (!FalseConst || !FalseConst->Value.isExactlyValue(0.0))
return false;

// Check if TrueVal is not constant.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider updating the comment to reflect that constant support is deferred:

// Check if TrueVal is not constant.
// TODO: Support constants in a follow-up patch. Currently disabled due to
// codegen issues that need investigation.
if (TrueConst)
  return false;

auto TempTrueVal = SelMI->getOperand(2).getReg();
auto TrueConst = getAnyConstantVRegValWithLookThrough(TempTrueVal, MRI);
if (TrueConst)
return false;

// Check if select and the add/sub are in same loop context.
Copy link

@michaelselehov michaelselehov Nov 7, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Optional: Consider adding a comment explaining why we skip loop-invariant selects:

// Check if select and the add/sub are in same loop context.
// Even for loop-invariant select (which might get hoisted), we skip
// the optimization because it wouldn't provide benefit in the loop body
// (same 1 instruction, but worse register pressure: 2 vs 4+ registers).
if (MLI->getLoopFor(MI.getParent()) != MLI->getLoopFor(SelMI->getParent()))
  return false;

This will help future maintainers understand the reasoning.

if (MLI->getLoopFor(MI.getParent()) != MLI->getLoopFor(SelMI->getParent()))
return false;

// Only publish the outputs once every check has passed.
TrueVal = TempTrueVal;
Cond = SelMI->getOperand(1).getReg();
return true;
};

Register Cond, C;
if (MI.getOpcode() == TargetOpcode::G_FSUB) {
// Pattern: fsub a, (select cond, c, 0.0)
if (matchSelectWithZero(RHSMI, Cond, C)) {
MatchInfo = {Cond, C, A};
return true;
}
} else if (MI.getOpcode() == TargetOpcode::G_FADD) {
// Pattern 1: fadd a, (fneg (select cond, c, 0.0))
if (RHSMI->getOpcode() == TargetOpcode::G_FNEG) {
Register SelReg = RHSMI->getOperand(1).getReg();
MachineInstr *SelMI = MRI.getVRegDef(SelReg);
if (matchSelectWithZero(SelMI, Cond, C)) {
MatchInfo = {Cond, C, A};
return true;
}
}

// Pattern 2: fadd a, (select cond, (fneg c), 0.0)
if (matchSelectWithZero(RHSMI, Cond, C)) {
// Check if C is fneg
MachineInstr *CMI = MRI.getVRegDef(C);
if (CMI && CMI->getOpcode() == TargetOpcode::G_FNEG) {
// C currently names the fneg result; record the fneg's source so the FMA
// multiplies -1.0 by the un-negated value.
C = CMI->getOperand(1).getReg();
MatchInfo = {Cond, C, A};
return true;
}
}
}
return false;
}

// Rewrites a matched conditional subtraction as
//   Dst = fma(select(Cond, -1.0, 0.0), C, A)
// using the operands recorded in MatchInfo, then deletes the original
// G_FSUB/G_FADD. The instruction flags of MI are carried over to the FMA.
void AMDGPUPostLegalizerCombinerImpl::applyConditionalSubToFMA(
    MachineInstr &MI, const ConditionalSubToFMAMatchInfo &MatchInfo) const {
  const Register ResultReg = MI.getOperand(0).getReg();
  const LLT ResultTy = MRI.getType(ResultReg);

  // Multiplier is -1.0 when the condition holds (subtract C) and 0.0 otherwise
  // (leave A unchanged).
  const auto NegOne = B.buildFConstant(ResultTy, APFloat(-1.0));
  const auto PosZero = B.buildFConstant(ResultTy, APFloat(0.0));
  const auto Multiplier =
      B.buildSelect(ResultTy, MatchInfo.Cond, NegOne, PosZero);

  // Define the original destination register directly with the FMA.
  B.buildFMA(ResultReg, Multiplier, MatchInfo.C, MatchInfo.A, MI.getFlags());

  MI.eraseFromParent();
}

// Pass boilerplate
// ================

Expand Down Expand Up @@ -467,6 +588,8 @@ void AMDGPUPostLegalizerCombiner::getAnalysisUsage(AnalysisUsage &AU) const {
if (!IsOptNone) {
AU.addRequired<MachineDominatorTreeWrapperPass>();
AU.addPreserved<MachineDominatorTreeWrapperPass>();
AU.addRequired<MachineLoopInfoWrapperPass>();
AU.addPreserved<MachineLoopInfoWrapperPass>();
}
MachineFunctionPass::getAnalysisUsage(AU);
}
Expand Down Expand Up @@ -494,6 +617,8 @@ bool AMDGPUPostLegalizerCombiner::runOnMachineFunction(MachineFunction &MF) {
MachineDominatorTree *MDT =
IsOptNone ? nullptr
: &getAnalysis<MachineDominatorTreeWrapperPass>().getDomTree();
MachineLoopInfo *MLI =
IsOptNone ? nullptr : &getAnalysis<MachineLoopInfoWrapperPass>().getLI();

CombinerInfo CInfo(/*AllowIllegalOps*/ false, /*ShouldLegalizeIllegal*/ true,
LI, EnableOpt, F.hasOptSize(), F.hasMinSize());
Expand All @@ -503,7 +628,7 @@ bool AMDGPUPostLegalizerCombiner::runOnMachineFunction(MachineFunction &MF) {
// Legalizer performs DCE, so a full DCE pass is unnecessary.
CInfo.EnableFullDCE = false;
AMDGPUPostLegalizerCombinerImpl Impl(MF, CInfo, TPC, *VT, /*CSEInfo*/ nullptr,
RuleConfig, ST, MDT, LI);
RuleConfig, ST, MDT, MLI, LI);
return Impl.combineMachineInstrs();
}

Expand All @@ -513,6 +638,7 @@ INITIALIZE_PASS_BEGIN(AMDGPUPostLegalizerCombiner, DEBUG_TYPE,
false)
INITIALIZE_PASS_DEPENDENCY(TargetPassConfig)
INITIALIZE_PASS_DEPENDENCY(GISelValueTrackingAnalysisLegacy)
INITIALIZE_PASS_DEPENDENCY(MachineLoopInfoWrapperPass)
INITIALIZE_PASS_END(AMDGPUPostLegalizerCombiner, DEBUG_TYPE,
"Combine AMDGPU machine instrs after legalization", false,
false)
Expand Down
Loading