@@ -1122,48 +1122,134 @@ bool VPlanTransforms::handleMaxMinNumReductions(VPlan &Plan) {
11221122 return true ;
11231123}
11241124
1125- // / Try to convert FindLastIV to FindFirstIV reduction when using a strict
1126- // / predicate. Returns the new FindFirstIVPhiR on success, nullptr on failure.
1127- static VPReductionPHIRecipe *
1128- tryConvertToFindFirstIV (VPlan &Plan, VPReductionPHIRecipe *FindLastIVPhiR,
1129- VPValue *IVOp, ScalarEvolution &SE, const Loop *L) {
1130- Type *Ty = VPTypeAnalysis (Plan).inferScalarType (FindLastIVPhiR);
1131- unsigned NumBits = Ty->getIntegerBitWidth ();
1132-
1133- // Determine the reduction kind and sentinel based on the IV range.
1134- RecurKind NewKind;
1135- VPValue *NewSentinel;
1136- auto *AR = cast<SCEVAddRecExpr>(vputils::getSCEVExprForVPValue (IVOp, SE, L));
1137- if (RecurrenceDescriptor::isValidIVRangeForFindIV (
1138- AR, /* IsSigned=*/ true , /* IsFindFirstIV=*/ true , SE)) {
1139- NewKind = RecurKind::FindFirstIVSMin;
1140- NewSentinel = Plan.getConstantInt (APInt::getSignedMaxValue (NumBits));
1141- } else if (RecurrenceDescriptor::isValidIVRangeForFindIV (
1142- AR, /* IsSigned=*/ false , /* IsFindFirstIV=*/ true , SE)) {
1143- NewKind = RecurKind::FindFirstIVUMin;
1144- NewSentinel = Plan.getConstantInt (APInt::getMaxValue (NumBits));
1145- } else {
1146- return nullptr ;
1125+ // / For argmin/argmax reductions with strict predicates, convert the existing
1126+ // / FindLastIV reduction to a new UMin reduction of a wide canonical IV. If the
1127+ // / original IV was not canonical, a new canonical wide IV is added, and the
1128+ // / final result is scaled back to the original IV.
1129+ static bool handleStrictArgMinArgMax (VPlan &Plan,
1130+ VPReductionPHIRecipe *MinMaxPhiR,
1131+ VPReductionPHIRecipe *FindIVPhiR,
1132+ VPWidenIntOrFpInductionRecipe *WideIV,
1133+ VPInstruction *MinMaxResult) {
1134+ Type *Ty = Plan.getVectorLoopRegion ()->getCanonicalIVType ();
1135+ if (Ty != VPTypeAnalysis (Plan).inferScalarType (FindIVPhiR))
1136+ return false ;
1137+
1138+ // If the original wide IV is not canonical, create a new one. The wide IV is
1139+ // guaranteed to not wrap for all lanes that are active in the vector loop.
1140+ if (!WideIV->isCanonical ()) {
1141+ VPValue *Zero = Plan.getOrAddLiveIn (ConstantInt::get (Ty, 0 ));
1142+ VPValue *One = Plan.getOrAddLiveIn (ConstantInt::get (Ty, 1 ));
1143+ auto *WidenCanIV = new VPWidenIntOrFpInductionRecipe (
1144+ nullptr , Zero, One, WideIV->getVFValue (),
1145+ WideIV->getInductionDescriptor (), VPIRFlags (), WideIV->getDebugLoc ());
1146+ WidenCanIV->insertBefore (WideIV);
1147+
1148+ // Update the select to use the wide canonical IV.
1149+ auto *SelectRecipe = cast<VPSingleDefRecipe>(
1150+ FindIVPhiR->getBackedgeValue ()->getDefiningRecipe ());
1151+ if (SelectRecipe->getOperand (1 ) == WideIV)
1152+ SelectRecipe->setOperand (1 , WidenCanIV);
1153+ else if (SelectRecipe->getOperand (2 ) == WideIV)
1154+ SelectRecipe->setOperand (2 , WidenCanIV);
11471155 }
11481156
1149- // Create the new FindFirstIV reduction recipe.
1150- assert (!FindLastIVPhiR->isInLoop () && !FindLastIVPhiR->isOrdered ());
1151- ReductionStyle Style = RdxUnordered{FindLastIVPhiR->getVFScaleFactor ()};
1152- auto *FindFirstIVPhiR =
1153- new VPReductionPHIRecipe (nullptr , NewKind, *NewSentinel, Style,
1154- FindLastIVPhiR->hasUsesOutsideReductionChain ());
1155- FindFirstIVPhiR->addOperand (FindLastIVPhiR->getBackedgeValue ());
1157+ // Create the new UMin reduction recipe to track the minimum index.
1158+ assert (!FindIVPhiR->isInLoop () && !FindIVPhiR->isOrdered () &&
1159+ " inloop and ordered reductions not supported" );
1160+ VPValue *MaxInt =
1161+ Plan.getConstantInt (APInt::getMaxValue (Ty->getIntegerBitWidth ()));
1162+ ReductionStyle Style = RdxUnordered{FindIVPhiR->getVFScaleFactor ()};
1163+ auto *MinIdxPhiR = new VPReductionPHIRecipe (
1164+ dyn_cast_or_null<PHINode>(FindIVPhiR->getUnderlyingValue ()),
1165+ RecurKind::UMin, *MaxInt, *FindIVPhiR->getBackedgeValue (), Style,
1166+ FindIVPhiR->hasUsesOutsideReductionChain ());
1167+ MinIdxPhiR->insertBefore (FindIVPhiR);
11561168
1157- FindFirstIVPhiR->insertBefore (FindLastIVPhiR);
11581169 VPInstruction *FindLastIVResult =
1159- findUserOf<VPInstruction::ComputeFindIVResult>(FindLastIVPhiR);
1160- FindLastIVPhiR->replaceAllUsesWith (FindFirstIVPhiR);
1161- FindLastIVResult->setOperand (2 , NewSentinel);
1162- return FindFirstIVPhiR;
1170+ findUserOf<VPInstruction::ComputeFindIVResult>(FindIVPhiR);
1171+ MinMaxResult->moveBefore (*FindLastIVResult->getParent (),
1172+ FindLastIVResult->getIterator ());
1173+
1174+ // The reduction using MinMaxPhiR needs adjusting to compute the correct
1175+ // result:
1176+ // 1. We need to find the first canonical IV for which the condition based
1177+ // on the min/max recurrence is true,
1178+ // 2. Compare the partial min/max reduction result to its final value and,
1179+ // 3. Select the lanes of the partial UMin reduction of the canonical wide
1180+ // IV which correspond to the lanes matching the min/max reduction result.
1181+ // 4. Scale the final select canonical IV back to the original IV using
1182+ // VPDerivedIVRecipe.
1183+ // 5. If the minimum value matches the start value, the condition in the
1184+ // loop was never true, return the start value in that case.
1185+ //
1186+ // The original reductions need adjusting:
1187+ // For example, this transforms
1188+ // vp<%min.result> = compute-reduction-result ir<%min.val>,
1189+ // ir<%min.val.next>
1190+ // vp<%find.iv.result = compute-find-iv-result ir<%min.idx>, ir<0>,
1191+ // SENTINEL, vp<%min.idx.next>
1192+ //
1193+ // into:
1194+ // vp<%min.result> = compute-reduction-result ir<%min.val>, ir<%min.val.next>
1195+ // vp<%final.min.cmp> = icmp eq ir<%min.val.next>, vp<%min.result>
1196+ // vp<%final.min.iv> = select vp<%final.min.cmp>, ir<%min.idx.next>, ir<-1>
1197+ // vp<%13> = compute-reduction-result ir<%min.idx>, vp<%final.min.iv>
1198+ // vp<%scaled.result.iv> = DERIVED-IV ir<20> + vp<%13> * ir<1>
1199+ // vp<%threshold.cmp> = icmp slt vp<%min.result>, ir<0>
1200+ // vp<%final.result> = select vp<%threshold.cmp>, vp<%scaled.result.iv>,
1201+ // ir<%original.start>
1202+
1203+ VPBuilder Builder (FindLastIVResult);
1204+ VPValue *MinMaxExiting = MinMaxResult->getOperand (1 );
1205+ auto *FinalMinMaxCmp =
1206+ Builder.createICmp (CmpInst::ICMP_EQ, MinMaxExiting, MinMaxResult);
1207+ VPValue *LastIVExiting = FindLastIVResult->getOperand (3 );
1208+ auto *FinalIVSelect =
1209+ Builder.createSelect (FinalMinMaxCmp, LastIVExiting, MaxInt);
1210+ VPSingleDefRecipe *FinalResult = Builder.createNaryOp (
1211+ VPInstruction::ComputeReductionResult, {MinIdxPhiR, FinalIVSelect}, {},
1212+ FindLastIVResult->getDebugLoc ());
1213+
1214+ // If we used a new wide canonical IV convert the reduction result back to the
1215+ // original IV scale before the final select.
1216+ if (!WideIV->isCanonical ()) {
1217+ auto *DerivedIVRecipe =
1218+ new VPDerivedIVRecipe (InductionDescriptor::IK_IntInduction,
1219+ nullptr , // No FPBinOp for integer induction
1220+ WideIV->getStartValue (), FinalResult,
1221+ WideIV->getStepValue (), " derived.iv.result" );
1222+ DerivedIVRecipe->insertBefore (&*Builder.getInsertPoint ());
1223+ FinalResult = DerivedIVRecipe;
1224+ }
1225+
1226+ auto GetPred = [&MinMaxPhiR]() {
1227+ switch (MinMaxPhiR->getRecurrenceKind ()) {
1228+ case RecurKind::UMin:
1229+ return CmpInst::ICMP_ULT;
1230+ case RecurKind::SMin:
1231+ return CmpInst::ICMP_SLT;
1232+ case RecurKind::UMax:
1233+ return CmpInst::ICMP_UGT;
1234+ case RecurKind::SMax:
1235+ return CmpInst::ICMP_SGT;
1236+ default :
1237+ llvm_unreachable (" must be an integer min/max recurrence kind" );
1238+ }
1239+ };
1240+ // If the final min/max value matches the start value, the condition in the
1241+ // loop was always false, i.e. no induction value has been selected. If that's
1242+ // the case, use the original start value.
1243+ VPValue *MinMaxLT =
1244+ Builder.createICmp (GetPred (), MinMaxResult, MinMaxPhiR->getStartValue ());
1245+ VPValue *Res = Builder.createSelect (MinMaxLT, FinalResult,
1246+ FindLastIVResult->getOperand (1 ));
1247+ FindIVPhiR->replaceAllUsesWith (MinIdxPhiR);
1248+ FindLastIVResult->replaceAllUsesWith (Res);
1249+ return true ;
11631250}
11641251
1165- bool VPlanTransforms::handleMultiUseReductions (VPlan &Plan, ScalarEvolution &SE,
1166- const Loop *L) {
1252+ bool VPlanTransforms::handleMultiUseReductions (VPlan &Plan) {
11671253 for (auto &PhiR : make_early_inc_range (
11681254 Plan.getVectorLoopRegion ()->getEntryBasicBlock ()->phis ())) {
11691255 auto *MinMaxPhiR = dyn_cast<VPReductionPHIRecipe>(&PhiR);
@@ -1174,7 +1260,7 @@ bool VPlanTransforms::handleMultiUseReductions(VPlan &Plan, ScalarEvolution &SE,
11741260 // MinMaxPhiR has users outside the reduction cycle in the loop. Check if
11751261 // the only other user is a FindLastIV reduction. MinMaxPhiR must have
11761262 // exactly 3 users: 1) the min/max operation, the compare of a FindLastIV
1177- // reduction and ComputeReductionResult. The comparisom must compare
1263+ // reduction and ComputeReductionResult. The comparison must compare
11781264 // MinMaxPhiR against the min/max operand used for the min/max reduction
11791265 // and only be used by the select of the FindLastIV reduction.
11801266 RecurKind RdxKind = MinMaxPhiR->getRecurrenceKind ();
@@ -1273,13 +1359,14 @@ bool VPlanTransforms::handleMultiUseReductions(VPlan &Plan, ScalarEvolution &SE,
12731359 if (!IsValidPredicate)
12741360 return false ;
12751361
1276- // For strict predicates, transform try to convert FindLastIV to
1277- // FindFirstIV.
1362+ // For strict predicates, use a UMin reduction to find the minimum index.
1363+ // Canonical IVs (0, 1, 2, ...) are guaranteed not to wrap in the vector
1364+ // loop, so UMin can always be used.
12781365 bool IsStrictPredicate = ICmpInst::isLT (Pred) || ICmpInst::isGT (Pred);
12791366 if (IsStrictPredicate) {
1280- FindIVPhiR = tryConvertToFindFirstIV (Plan, FindIVPhiR, IVOp, SE, L);
1281- if (!FindIVPhiR)
1282- return false ;
1367+ return handleStrictArgMinArgMax (Plan, MinMaxPhiR, FindIVPhiR,
1368+ cast<VPWidenIntOrFpInductionRecipe>(IVOp),
1369+ MinMaxResult) ;
12831370 }
12841371
12851372 // The reduction using MinMaxPhiR needs adjusting to compute the correct
0 commit comments