From c7a4e051cc3e799c859f2e134752d26084cc4bbd Mon Sep 17 00:00:00 2001 From: Abdalrhman Mohamed Date: Wed, 21 Jan 2026 09:06:00 -0800 Subject: [PATCH 1/4] feat: refactor addition and subtraction to use unpacked floats --- Fp/Addition.lean | 163 ++++++++++++++++++++++---------------------- Fp/Basic.lean | 2 +- Fp/Negation.lean | 45 ++++-------- Fp/Subtraction.lean | 78 ++++++++++++++------- 4 files changed, 146 insertions(+), 142 deletions(-) diff --git a/Fp/Addition.lean b/Fp/Addition.lean index 17d2d7e..3955913 100644 --- a/Fp/Addition.lean +++ b/Fp/Addition.lean @@ -1,92 +1,91 @@ import Fp.Basic -import Fp.Rounding +import Fp.UnpackedRound -/-- -Addition of two fixed-point numbers. - -When the sum is zero, the sign of the zero is dependent on the provided -rounding mode. --/ -@[bv_normalize] -def f_add (mode : RoundingMode) (a b : FixedPoint w e) : FixedPoint (w+1) e := - let hExOffset : e < w+1 := by - exact Nat.lt_add_right 1 a.hExOffset - let ax := BitVec.setWidth' (by omega) a.val - let bx := BitVec.setWidth' (by omega) b.val - if a.sign == b.sign then - -- Addition of same-signed numbers always preserves sign - { - sign := a.sign - val := BitVec.add ax bx - hExOffset := hExOffset - } - else if BitVec.ult ax bx then - { - sign := b.sign - val := BitVec.sub bx ax - hExOffset := hExOffset - } - else if BitVec.ult bx ax then +def UnpackedFloat.add (sign : Bool) (x y : UnpackedFloat e s) : UnpackedFloat (e + 1) (s + 2) := + -- Compute absolute exponent difference to determine significand shift amount. + let expDiff : BitVec (e + 1) := x.ex.signExtend (e + 1) - y.ex.signExtend (e + 1) + let absExpDiff := bif expDiff.msb then -expDiff else expDiff + -- Determine the smaller number whose significand we are going to shift. + let (x, y) := bif expDiff.msb || absExpDiff == 0 && x.sig.ult y.sig then (y, x) else (x, y) + -- Extend by 1 bit to the left to account for overflow and two bits to the right to account for + -- round and sticky bits. 
+ let xSig := x.sig.setWidth' (by omega) ++ 0#2 + let ySig := y.sig.setWidth' (by omega) ++ 0#2 + -- Reuse the same circuit for both addition and subtraction by negating the smaller significand. + -- Note: we always treat significands as unsigned integers. However, we make an exception here + -- to reuse the same adder circuit for both addition and subtraction. + let ySig := bif x.sign == y.sign then ySig else -ySig + -- Cap the right shift at `s + 3` in case `absExpDiff` is too big. + let shiftAmount := bif absExpDiff.ult (s + 3) then absExpDiff.setWidth (s + 3) else (s + 3) + -- Note: we use signed shift for `ySig` here to preserve its sign since it's now a signed integer. + let sigSum : BitVec (s + 3) := xSig + ySig.sshiftRight' shiftAmount + -- Sticky bit depends on bits we lose when we right shift `ySig` and `sigSum` (in case of an overflow). + let sticky := ySig &&& shiftAmount.orderEncode != 0 || sigSum.msb && sigSum.getLsb 0 + let sumResult := { - sign := a.sign - val := BitVec.sub ax bx - hExOffset := hExOffset + -- Sign of sum is sign of the bigger number! + sign := x.sign + -- Exponent of sum is exponent of bigger number (`+1` if there is an overflow). + ex := x.ex.signExtend (e + 1) + (BitVec.ofBool sigSum.msb).setWidth' (by omega) + -- Renormalize `sigSum` if there is an overflow. + sig := (sigSum >>> BitVec.ofBool sigSum.msb ||| (BitVec.ofBool sticky).setWidth' (by omega)).truncate (s + 2) } + bif sigSum == 0 then + -- Full cancellation: return zero. This case could have been merged with the second branch if not + -- for the sign, which depends on the rounding mode. + .mkZero sign + else bif !sigSum.getMsb 0 && !sigSum.getMsb 1 then + -- Catastrophic cancellation: we have to normalize. + sumResult.normalize else - -- Signs are different but values are same, so return +0.0 - -- When rounding mode is RTN we should instead return -0.0 - { - sign := mode = .RTN - val := 0#_ - hExOffset := hExOffset - } - -/-- -Addition of two extended fixed-point numbers. 
+ sumResult -When the sum is zero, the sign of the zero is dependent on the provided -rounding mode. --/ -@[bv_normalize] -def e_add (mode : RoundingMode) (a b : EFixedPoint w e) : EFixedPoint (w+1) e := - open EFixedPoint in - let hExOffset : e < w + 1 := by - exact Nat.lt_add_right 1 a.num.hExOffset - -- As of 2025-04-14, bv_decide does not support pattern matches on more than - -- one variable, so we'll have to deal with if-statements for now - if hN : a.state = .NaN || b.state = .NaN then getNaN hExOffset - else if hI1 : a.state = .Infinity && b.state = .Infinity then - if a.num.sign == b.num.sign then getInfinity a.num.sign hExOffset - else getNaN hExOffset - else if hI2 : a.state = .Infinity then getInfinity a.num.sign hExOffset - else if hI3 : b.state = .Infinity then getInfinity b.num.sign hExOffset +def EUnpackedFloat.add (m : RoundingMode) (x y : EUnpackedFloat (exponentWidth e s) (s + 1)) + : EUnpackedFloat (exponentWidth e s) (s + 1) := + bif x.isZero && !y.isZero then + y + else bif !x.isZero && y.isZero then + x + else bif x.isNaN || y.isNaN || x.isInfinite && y.isInfinite && x.sign != y.sign then + .mkNaN + else bif x.isInfinite && y.isInfinite && x.sign == y.sign || + x.isInfinite && !y.isInfinite || !x.isInfinite && y.isInfinite then + .mkInfinity (bif x.isInfinite then x.sign else y.sign) + else bif x.isZero && y.isZero then + .mkZero (bif m == .RTN then x.sign || y.sign else x.sign && y.sign) else - -- is this how to do assertions? 
- let _ : a.state = .Number && b.state = .Number := by - cases ha : a.state <;> cases hb : b.state <;> simp_all - { - state := .Number - num := f_add mode a.num b.num - } + UnpackedFloat.round (.add (m == .RTN) x.num y.num) m + +namespace PackedFloat + +def add (m : RoundingMode) (x y : PackedFloat e s) : PackedFloat e s := + (EUnpackedFloat.add m x.unpack y.unpack).pack + +instance : Add (PackedFloat e s) where + add := .add .RNE + +end PackedFloat + +/-- info: ExtRat.Number (-1 : Rat)/16384 -/ +#guard_msgs in #eval (PackedFloat.ofBits 5 2 0b10000100#8).toExtRat +/-- info: ExtRat.Number (5 : Rat)/8192 -/ +#guard_msgs in #eval (PackedFloat.ofBits 5 2 0b00010001#8).toExtRat + +/-- info: 9 / 16384 -/ +#guard_msgs in #eval (-1 : Rat)/16384 + (5 : Rat)/8192 + +/-- info: ExtRat.Number (1 : Rat)/2048 -/ +#guard_msgs in #eval (PackedFloat.ofRat 5 2 .RNE 9 16384).toExtRat -/-- -Addition of two floating point numbers, rounded to a floating point number -using the provided rounding mode. --/ -@[bv_normalize] -def add (a b : PackedFloat e s) (mode : RoundingMode) : PackedFloat e s := - EFixedPoint.round _ _ mode (e_add mode a.toEFixed b.toEFixed) +/-- info: ExtRat.Number (5 : Rat)/8192 -/ +#guard_msgs in #eval (PackedFloat.ofRat 5 2 .RNA 9 16384).toExtRat --- Proof by brute force (it takes a while) -/- -theorem PackedFloat_add_comm' (m : RoundingMode) (a b : PackedFloat 5 2) - : (add a b m) = (add b a m) := by - apply PackedFloat.inj - simp [add, e_add, f_add, round, PackedFloat.toEFixed, -BitVec.shiftLeft_eq', -BitVec.ushiftRight_eq'] - bv_decide --/ +/-- info: ExtRat.Number (1 : Rat)/2048 -/ +#guard_msgs in #eval (PackedFloat.add .RNE (PackedFloat.ofBits 5 2 0b10000100#8) (PackedFloat.ofBits 5 2 0b00010001#8)).toExtRat +/-- info: ExtRat.Number (5 : Rat)/8192 -/ +#guard_msgs in #eval (PackedFloat.add .RNA (PackedFloat.ofBits 5 2 0b10000100#8) (PackedFloat.ofBits 5 2 0b00010001#8)).toExtRat -/-- info: { sign := +, ex := 0x04#5, sig := 0x0#2 } -/ -#guard_msgs in #eval add 
(PackedFloat.ofBits 5 2 0b10000100#8) (PackedFloat.ofBits 5 2 0b00010001#8) .RNE -/-- info: { sign := +, ex := 0x04#5, sig := 0x1#2 } -/ -#guard_msgs in #eval add (PackedFloat.ofBits 5 2 0b10000100#8) (PackedFloat.ofBits 5 2 0b00010001#8) .RNA +/-- info: ExtRat.Infinity true -/ +#guard_msgs in #eval (PackedFloat.add .RNE (PackedFloat.ofBits 5 2 0b11111100#8) (PackedFloat.ofBits 5 2 0b00010001#8)).toExtRat +/-- info: ExtRat.Infinity false -/ +#guard_msgs in #eval (PackedFloat.add .RNE (PackedFloat.ofBits 5 2 0b01111100#8) (PackedFloat.ofBits 5 2 0b00010001#8)).toExtRat diff --git a/Fp/Basic.lean b/Fp/Basic.lean index c86a198..f70569c 100644 --- a/Fp/Basic.lean +++ b/Fp/Basic.lean @@ -963,7 +963,7 @@ def isZero (uf : UnpackedFloat e s) : Bool := @[bv_normalize] def normalize (uf : UnpackedFloat e s) : UnpackedFloat e s := - bif uf.sig.clz == s then + bif uf.sig == 0 then -- zero case: make it explicit! mkZero uf.sign else diff --git a/Fp/Negation.lean b/Fp/Negation.lean index f8f4b74..8665622 100644 --- a/Fp/Negation.lean +++ b/Fp/Negation.lean @@ -1,40 +1,19 @@ import Fp.Basic -import Fp.Rounding +import Fp.Packing -/-- Negate the fixed point number -/ -@[bv_normalize] -def f_neg (a : FixedPoint w e) : FixedPoint w e := { a with sign := !a.sign } +def UnpackedFloat.neg (x : UnpackedFloat e s) : UnpackedFloat e s := + { x with sign := !x.sign } -/-- Negate the extended fixed point number -/ -@[bv_normalize] -def e_neg (a : EFixedPoint w e) : EFixedPoint w e := - open EFixedPoint in - have := a.num.hExOffset - if hN : a.state = .NaN then - getNaN (by omega) - else if hInf : a.state = .Infinity then - getInfinity (!a.num.sign) (by omega) - else - let _ : a.state = .Number := by - cases h : a.state <;> simp_all - { a with num := f_neg a.num } +def EUnpackedFloat.neg (x : EUnpackedFloat (exponentWidth e s) (s + 1)) + : EUnpackedFloat (exponentWidth e s) (s + 1) := + .mkNumber x.num.neg -/-- Negate a floating-point number, by conversion to a fixed-point number. 
-/ -@[bv_normalize] -def negfixed (a : PackedFloat e s) (mode : RoundingMode) : PackedFloat e s := - EFixedPoint.round _ _ mode (e_neg a.toEFixed) +namespace PackedFloat -/-- -Negate a floating-point number, by flipping the sign bit. +def neg (x : PackedFloat e s) : PackedFloat e s := + x.unpack.neg.pack -This implements the same function as `negfixed`, but is much simpler. --/ -@[bv_normalize] -def neg (a : PackedFloat e s) : PackedFloat e s := - if a.isNaN then PackedFloat.getNaN _ _ - else { a with sign := !a.sign } +instance : Neg (PackedFloat e s) where + neg := .neg -@[bv_normalize] -def abs (a : PackedFloat e s) : PackedFloat e s := - if a.isNaN then PackedFloat.getNaN _ _ - else { a with sign := false } +end PackedFloat diff --git a/Fp/Subtraction.lean b/Fp/Subtraction.lean index 28c1942..481267a 100644 --- a/Fp/Subtraction.lean +++ b/Fp/Subtraction.lean @@ -1,29 +1,55 @@ -import Fp.Basic -import Fp.Rounding import Fp.Addition import Fp.Negation -/-- -Subtraction of two extended fixed-point numbers. --/ -@[bv_normalize] -def e_sub (mode : RoundingMode) (a b : EFixedPoint w e) : EFixedPoint (w+1) e := - e_add mode a (e_neg b) - -/-- -Subtraction of two floating-point numbers. - -Implemented entirely within EFixedPoint using `e_sub`. --/ -@[bv_normalize] -def subfixed (a b : PackedFloat e s) (mode : RoundingMode) : PackedFloat e s := - EFixedPoint.round _ _ mode (e_sub mode a.toEFixed b.toEFixed) - -/-- -Subtraction of two floating-point numbers. - -Implemented as a negation followed by an addition. 
--/ -@[bv_normalize] -def sub (a b : PackedFloat e s) (mode : RoundingMode) : PackedFloat e s := - add a (neg b) mode +def UnpackedFloat.sub (sign : Bool) (x y : UnpackedFloat e s) : UnpackedFloat (e + 1) (s + 2) := + .add sign x y.neg + +def EUnpackedFloat.sub (m : RoundingMode) (x y : EUnpackedFloat (exponentWidth e s) (s + 1)) + : EUnpackedFloat (exponentWidth e s) (s + 1) := + bif x.isZero && !y.isZero then + y.neg + else bif !x.isZero && y.isZero then + x + else bif x.isNaN || y.isNaN || x.isInfinite && y.isInfinite && x.sign == y.sign then + .mkNaN + else bif x.isInfinite && y.isInfinite && x.sign != y.sign || + x.isInfinite && !y.isInfinite || !x.isInfinite && y.isInfinite then + .mkInfinity (bif x.isInfinite then x.sign else !y.sign) + else bif x.isZero && y.isZero then + .mkZero (bif m == .RTN then x.sign || !y.sign else x.sign && !y.sign) + else + UnpackedFloat.round (.sub (m == .RTN) x.num y.num) m + +namespace PackedFloat + +def sub (m : RoundingMode) (x y : PackedFloat e s) : PackedFloat e s := + (EUnpackedFloat.sub m x.unpack y.unpack).pack + +instance : Sub (PackedFloat e s) where + sub := .sub .RNE + +end PackedFloat + +/-- info: ExtRat.Number (-1 : Rat)/16384 -/ +#guard_msgs in #eval (PackedFloat.ofBits 5 2 0b10000100#8).toExtRat +/-- info: ExtRat.Number (5 : Rat)/8192 -/ +#guard_msgs in #eval (PackedFloat.ofBits 5 2 0b00010001#8).toExtRat + +/-- info: -11 / 16384 -/ +#guard_msgs in #eval (-1 : Rat)/16384 - (5 : Rat)/8192 + +/-- info: ExtRat.Number (-3 : Rat)/4096 -/ +#guard_msgs in #eval (PackedFloat.ofRat 5 2 .RNE (-11) 16384).toExtRat + +/-- info: ExtRat.Number (-3 : Rat)/4096 -/ +#guard_msgs in #eval (PackedFloat.ofRat 5 2 .RNA (-11) 16384).toExtRat + +/-- info: ExtRat.Number (-3 : Rat)/4096 -/ +#guard_msgs in #eval (PackedFloat.sub .RNE (PackedFloat.ofBits 5 2 0b10000100#8) (PackedFloat.ofBits 5 2 0b00010001#8)).toExtRat +/-- info: ExtRat.Number (-3 : Rat)/4096 -/ +#guard_msgs in #eval (PackedFloat.sub .RNA (PackedFloat.ofBits 5 2 
0b10000100#8) (PackedFloat.ofBits 5 2 0b00010001#8)).toExtRat + +/-- info: ExtRat.Infinity true -/ +#guard_msgs in #eval (PackedFloat.sub .RNE (PackedFloat.ofBits 5 2 0b11111100#8) (PackedFloat.ofBits 5 2 0b00010001#8)).toExtRat +/-- info: ExtRat.Infinity false -/ +#guard_msgs in #eval (PackedFloat.sub .RNE (PackedFloat.ofBits 5 2 0b01111100#8) (PackedFloat.ofBits 5 2 0b00010001#8)).toExtRat From b6b70c34028c00a443ec0cfd940eef2fff804575 Mon Sep 17 00:00:00 2001 From: Abdalrhman Mohamed Date: Wed, 21 Jan 2026 11:14:03 -0800 Subject: [PATCH 2/4] chore: refactor abs to use unpacked floats --- Fp/Negation.lean | 42 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 42 insertions(+) diff --git a/Fp/Negation.lean b/Fp/Negation.lean index 8665622..73eca29 100644 --- a/Fp/Negation.lean +++ b/Fp/Negation.lean @@ -4,16 +4,58 @@ import Fp.Packing def UnpackedFloat.neg (x : UnpackedFloat e s) : UnpackedFloat e s := { x with sign := !x.sign } +def UnpackedFloat.abs (x : UnpackedFloat e s) : UnpackedFloat e s := + { x with sign := false } + def EUnpackedFloat.neg (x : EUnpackedFloat (exponentWidth e s) (s + 1)) : EUnpackedFloat (exponentWidth e s) (s + 1) := .mkNumber x.num.neg +def EUnpackedFloat.abs (x : EUnpackedFloat (exponentWidth e s) (s + 1)) + : EUnpackedFloat (exponentWidth e s) (s + 1) := + .mkNumber x.num.abs + namespace PackedFloat def neg (x : PackedFloat e s) : PackedFloat e s := + -- At first glance, unpacking a float just to modify its sign might look like + -- unnecessary work; after all, you *can* toggle the sign bit directly on the + -- packed representation. And if this operation lived in isolation, that would + -- indeed be the simpler route. + -- + -- However, most floating‑point operations already require unpacking, and once + -- you're working in that representation, it's usually best to stay there. 
+ -- Thanks to the identity + -- + -- `∀ uf m, (uf.round m).pack.unpack = uf.round m` + -- + -- we can keep everything in the unpacked form throughout a chain of + -- operations and only repack at the very end. Unpacking here helps maintain + -- that workflow and avoids bouncing back and forth between representations. + -- + -- TODO: is there a way to hint this optimization strategy to the compiler? x.unpack.neg.pack instance : Neg (PackedFloat e s) where neg := .neg +def abs (x : PackedFloat e s) : PackedFloat e s := + -- At first glance, unpacking a float just to modify its sign might look like + -- unnecessary work; after all, you *can* toggle the sign bit directly on the + -- packed representation. And if this operation lived in isolation, that would + -- indeed be the simpler route. + -- + -- However, most floating‑point operations already require unpacking, and once + -- you're working in that representation, it's usually best to stay there. + -- Thanks to the identity + -- + -- `∀ uf m, (uf.round m).pack.unpack = uf.round m` + -- + -- we can keep everything in the unpacked form throughout a chain of + -- operations and only repack at the very end. Unpacking here helps maintain + -- that workflow and avoids bouncing back and forth between representations. + -- + -- TODO: is there a way to hint this optimization strategy to the compiler? + x.unpack.abs.pack + end PackedFloat From 9ff706911481bae01956840eb7d900dac84d3c33 Mon Sep 17 00:00:00 2001 From: Abdalrhman Mohamed Date: Thu, 22 Jan 2026 06:11:14 -0800 Subject: [PATCH 3/4] chore: resolve build errors --- Fp/FMA.lean | 68 +++++++++++++++++++++++++++++++++++++++++ Fp/Proofs/Addition.lean | 2 +- Main.lean | 10 +++--- 3 files changed, 74 insertions(+), 6 deletions(-) diff --git a/Fp/FMA.lean b/Fp/FMA.lean index ce96203..7fc9ecf 100644 --- a/Fp/FMA.lean +++ b/Fp/FMA.lean @@ -3,6 +3,74 @@ import Fp.Rounding import Fp.Addition import Fp.Multiplication +/-- +Addition of two fixed-point numbers. 
+ +When the sum is zero, the sign of the zero is dependent on the provided +rounding mode. +-/ +@[bv_normalize] +def f_add (mode : RoundingMode) (a b : FixedPoint w e) : FixedPoint (w+1) e := + let hExOffset : e < w+1 := by + exact Nat.lt_add_right 1 a.hExOffset + let ax := BitVec.setWidth' (by omega) a.val + let bx := BitVec.setWidth' (by omega) b.val + if a.sign == b.sign then + -- Addition of same-signed numbers always preserves sign + { + sign := a.sign + val := BitVec.add ax bx + hExOffset := hExOffset + } + else if BitVec.ult ax bx then + { + sign := b.sign + val := BitVec.sub bx ax + hExOffset := hExOffset + } + else if BitVec.ult bx ax then + { + sign := a.sign + val := BitVec.sub ax bx + hExOffset := hExOffset + } + else + -- Signs are different but values are same, so return +0.0 + -- When rounding mode is RTN we should instead return -0.0 + { + sign := mode = .RTN + val := 0#_ + hExOffset := hExOffset + } + +/-- +Addition of two extended fixed-point numbers. + +When the sum is zero, the sign of the zero is dependent on the provided +rounding mode. +-/ +@[bv_normalize] +def e_add (mode : RoundingMode) (a b : EFixedPoint w e) : EFixedPoint (w+1) e := + open EFixedPoint in + let hExOffset : e < w + 1 := by + exact Nat.lt_add_right 1 a.num.hExOffset + -- As of 2025-04-14, bv_decide does not support pattern matches on more than + -- one variable, so we'll have to deal with if-statements for now + if hN : a.state = .NaN || b.state = .NaN then getNaN hExOffset + else if hI1 : a.state = .Infinity && b.state = .Infinity then + if a.num.sign == b.num.sign then getInfinity a.num.sign hExOffset + else getNaN hExOffset + else if hI2 : a.state = .Infinity then getInfinity a.num.sign hExOffset + else if hI3 : b.state = .Infinity then getInfinity b.num.sign hExOffset + else + -- is this how to do assertions? 
+ let _ : a.state = .Number && b.state = .Number := by + cases ha : a.state <;> cases hb : b.state <;> simp_all + { + state := .Number + num := f_add mode a.num b.num + } + @[bv_normalize] def fma (a b c : PackedFloat e s) (m : RoundingMode) : PackedFloat e s := diff --git a/Fp/Proofs/Addition.lean b/Fp/Proofs/Addition.lean index d51c3e3..9a5c240 100644 --- a/Fp/Proofs/Addition.lean +++ b/Fp/Proofs/Addition.lean @@ -1,6 +1,6 @@ import Fp.Proofs.Basic import Init.Data.Dyadic -import Fp.Addition +import Fp.FMA import Fp.ForLean.Dyadic import Fp.ForLean.Rat diff --git a/Main.lean b/Main.lean index 8467831..e3322bf 100644 --- a/Main.lean +++ b/Main.lean @@ -33,7 +33,7 @@ def test_add (f : FP8Format) (m : RoundingMode) (a b : BitVec 8) : OpResult := { oper := "add" mode := m - result := [a, b, f.h.mp (add a' b' m).toBits].map toDigits + result := [a, b, f.h.mp (PackedFloat.add m a' b').toBits].map toDigits } def test_sub (f : FP8Format) (m : RoundingMode) (a b : BitVec 8) : OpResult := @@ -42,7 +42,7 @@ def test_sub (f : FP8Format) (m : RoundingMode) (a b : BitVec 8) : OpResult := { oper := "sub" mode := m - result := [a, b, f.h.mp (sub a' b' m).toBits].map toDigits + result := [a, b, f.h.mp (PackedFloat.sub m a' b').toBits].map toDigits } def test_div (f : FP8Format) (m : RoundingMode) (a b : BitVec 8) : OpResult := @@ -95,7 +95,7 @@ def test_neg (f : FP8Format) (m : RoundingMode) (a : BitVec 8) : OpResult := { oper := "neg" mode := m - result := [a, 0#8, f.h.mp (PackedFloat.toBits (neg a'))].map toDigits + result := [a, 0#8, f.h.mp (PackedFloat.toBits a'.neg)].map toDigits } def test_abs (f : FP8Format) (m : RoundingMode) (a : BitVec 8) : OpResult := @@ -103,7 +103,7 @@ def test_abs (f : FP8Format) (m : RoundingMode) (a : BitVec 8) : OpResult := { oper := "abs" mode := m - result := [a, 0#8, f.h.mp (PackedFloat.toBits (abs a'))].map toDigits + result := [a, 0#8, f.h.mp (PackedFloat.toBits a'.abs)].map toDigits } def test_roundToInt (f : FP8Format) (m : RoundingMode) 
(a : BitVec 8) : OpResult := @@ -230,7 +230,7 @@ def main (args : List String) : IO Unit := do /-- info: { sign := -, ex := 0x04#5, sig := 0x1#2 } -/ -#guard_msgs in #eval add (PackedFloat.ofBits 5 2 0b00000011#8) (PackedFloat.ofBits 5 2 0b10010001#8) .RNE +#guard_msgs in #eval PackedFloat.add .RNE (PackedFloat.ofBits 5 2 0b00000011#8) (PackedFloat.ofBits 5 2 0b10010001#8) /-- info: { sign := +, ex := 0x01#5, sig := 0x2#2 } -/ #guard_msgs in #eval EFixedPoint.round 5 2 .RNE (PackedFloat.toEFixed {sign := false, ex := 1#5, sig := 2#2}) /-- info: { sign := +, ex := 0x1f#5, sig := 0x2#2 } -/ From a87a3be871372f938f1885bf3e44176448a3d0ee Mon Sep 17 00:00:00 2001 From: Abdalrhman Mohamed Date: Thu, 22 Jan 2026 11:13:45 -0800 Subject: [PATCH 4/4] fix: fix bugs in the refactored ops --- Fp/Addition.lean | 50 +++++++++++++++++++++++++++++++++--------------- Fp/Basic.lean | 9 ++++++--- Fp/Negation.lean | 4 ++-- Fp/Packing.lean | 7 +++---- Main.lean | 28 +++++++++++++++------------ 5 files changed, 62 insertions(+), 36 deletions(-) diff --git a/Fp/Addition.lean b/Fp/Addition.lean index 3955913..63ca72f 100644 --- a/Fp/Addition.lean +++ b/Fp/Addition.lean @@ -21,24 +21,20 @@ def UnpackedFloat.add (sign : Bool) (x y : UnpackedFloat e s) : UnpackedFloat (e let sigSum : BitVec (s + 3) := xSig + ySig.sshiftRight' shiftAmount -- Sticky bit depends on bits we lose when we right shift `ySig` and `sigSum` (in case of an overflow). let sticky := ySig &&& shiftAmount.orderEncode != 0 || sigSum.msb && sigSum.getLsb 0 - let sumResult := + let sum : UnpackedFloat (e + 1) (s + 2) := { -- Sign of sum is sign of the bigger number! sign := x.sign -- Exponent of sum is exponent of bigger number (`+1` if there is an overflow). ex := x.ex.signExtend (e + 1) + (BitVec.ofBool sigSum.msb).setWidth' (by omega) -- Renormalize `sigSum` if there is an overflow. 
- sig := (sigSum >>> BitVec.ofBool sigSum.msb ||| (BitVec.ofBool sticky).setWidth' (by omega)).truncate (s + 2) + sig := (sigSum >>> BitVec.ofBool sigSum.msb).truncate (s + 2) } - bif sigSum == 0 then - -- Full cancellation: return zero. This case could have been merged with the second branch if not - -- for the sign, which depends on the rounding mode. - .mkZero sign - else bif !sigSum.getMsb 0 && !sigSum.getMsb 1 then - -- Catastrophic cancellation: we have to normalize. - sumResult.normalize - else - sumResult + -- If a catastrophic cancellation occurred, we have to normalize. In case the sum is `0` (i.e., full + -- cancellation), the sign depends on the rounding mode. + let normSum := bif !sum.sig.msb then sum.normalize sign else sum + -- Sticky bit is independent of normalization: add it at the very end. + { normSum with sig := normSum.sig ||| (BitVec.ofBool sticky).setWidth' (by omega) } def EUnpackedFloat.add (m : RoundingMode) (x y : EUnpackedFloat (exponentWidth e s) (s + 1)) : EUnpackedFloat (exponentWidth e s) (s + 1) := @@ -66,20 +62,44 @@ instance : Add (PackedFloat e s) where end PackedFloat +-- Minor cancellation with rounding + +/-- info: ExtRat.Number (5 : Rat)/64 -/ +#guard_msgs in #eval (PackedFloat.ofBits 3 4 0b00000101).toExtRat +/-- info: ExtRat.Number -2 -/ +#guard_msgs in #eval (PackedFloat.ofBits 3 4 0b11000000).toExtRat +/-- info: -123 / 64 -/ +#guard_msgs in #eval (5 : Rat)/64 + -2 +/-- info: ExtRat.Number (-31 : Rat)/16 -/ +#guard_msgs in #eval (PackedFloat.ofRat 3 4 .RNE (-123) 64).toExtRat +/-- info: ExtRat.Number (-31 : Rat)/16 -/ +#guard_msgs in #eval (PackedFloat.add .RNE (PackedFloat.ofBits 3 4 0b00000101) (PackedFloat.ofBits 3 4 0b11000000)).toExtRat + +-- Minor cancellation without rounding + +/-- info: ExtRat.Number (5 : Rat)/64 -/ +#guard_msgs in #eval (PackedFloat.ofBits 3 4 0b00000101).toExtRat +/-- info: ExtRat.Number -4 -/ +#guard_msgs in #eval (PackedFloat.ofBits 3 4 0b11010000).toExtRat +/-- info: -251 / 64 -/ +#guard_msgs 
in #eval (5 : Rat)/64 + -4 +/-- info: ExtRat.Number (-31 : Rat)/8 -/ +#guard_msgs in #eval (PackedFloat.ofRat 3 4 .RNE (-251) 64).toExtRat +/-- info: ExtRat.Number (-31 : Rat)/8 -/ +#guard_msgs in #eval (PackedFloat.add .RNE (PackedFloat.ofBits 3 4 0b00000101) (PackedFloat.ofBits 3 4 0b11010000)).toExtRat + +-- Rounding Modes + /-- info: ExtRat.Number (-1 : Rat)/16384 -/ #guard_msgs in #eval (PackedFloat.ofBits 5 2 0b10000100#8).toExtRat /-- info: ExtRat.Number (5 : Rat)/8192 -/ #guard_msgs in #eval (PackedFloat.ofBits 5 2 0b00010001#8).toExtRat - /-- info: 9 / 16384 -/ #guard_msgs in #eval (-1 : Rat)/16384 + (5 : Rat)/8192 - /-- info: ExtRat.Number (1 : Rat)/2048 -/ #guard_msgs in #eval (PackedFloat.ofRat 5 2 .RNE 9 16384).toExtRat - /-- info: ExtRat.Number (5 : Rat)/8192 -/ #guard_msgs in #eval (PackedFloat.ofRat 5 2 .RNA 9 16384).toExtRat - /-- info: ExtRat.Number (1 : Rat)/2048 -/ #guard_msgs in #eval (PackedFloat.add .RNE (PackedFloat.ofBits 5 2 0b10000100#8) (PackedFloat.ofBits 5 2 0b00010001#8)).toExtRat /-- info: ExtRat.Number (5 : Rat)/8192 -/ diff --git a/Fp/Basic.lean b/Fp/Basic.lean index f70569c..3397636 100644 --- a/Fp/Basic.lean +++ b/Fp/Basic.lean @@ -879,6 +879,9 @@ def bias (e : Nat) : Nat := namespace PackedFloat +def mkNaN (sign := false) (sig := 1#s <<< (s - 1)) : PackedFloat e s := + { sign, ex := BitVec.allOnes e, sig } + def toExtDyadic (pf : PackedFloat e s) : ExtDyadic := bif pf.isNaN then .NaN @@ -962,10 +965,10 @@ def isZero (uf : UnpackedFloat e s) : Bool := uf.ex == 0 && uf.sig == 0 @[bv_normalize] -def normalize (uf : UnpackedFloat e s) : UnpackedFloat e s := - bif uf.sig == 0 then +def normalize (uf : UnpackedFloat e s) (sign := uf.sign) : UnpackedFloat e s := + bif uf.sig == 0#s then -- zero case: make it explicit! 
- mkZero uf.sign + mkZero sign else { sign := uf.sign diff --git a/Fp/Negation.lean b/Fp/Negation.lean index 73eca29..8e0865d 100644 --- a/Fp/Negation.lean +++ b/Fp/Negation.lean @@ -9,11 +9,11 @@ def UnpackedFloat.abs (x : UnpackedFloat e s) : UnpackedFloat e s := def EUnpackedFloat.neg (x : EUnpackedFloat (exponentWidth e s) (s + 1)) : EUnpackedFloat (exponentWidth e s) (s + 1) := - .mkNumber x.num.neg + { x with num := x.num.neg } def EUnpackedFloat.abs (x : EUnpackedFloat (exponentWidth e s) (s + 1)) : EUnpackedFloat (exponentWidth e s) (s + 1) := - .mkNumber x.num.abs + { x with num := x.num.abs } namespace PackedFloat diff --git a/Fp/Packing.lean b/Fp/Packing.lean index e0efe4d..960611c 100644 --- a/Fp/Packing.lean +++ b/Fp/Packing.lean @@ -36,14 +36,14 @@ def EUnpackedFloat.pack (uf : EUnpackedFloat (exponentWidth e s) (s + 1)) ex := bif uf.isNaN || uf.isInfinite then BitVec.allOnes e else if uf.isZero || !inNormalRange then - (0#_) + 0#e else -- bif uf.isNorm then -- Truncate msbs used to normalize subnormals (uf.exp + BitVec.ofNat _ (bias e)).truncate _ sig := bif uf.isNaN then - BitVec.ofNat _ (2 ^ (s - 1)) + 1#s <<< (s - 1) else bif uf.isInfinite || uf.isZero then - (0#_) + 0#s else bif inNormalRange then uf.sig.truncate s -- drop the leading 1 bit else -- bif uf.isSubnorm then @@ -51,7 +51,6 @@ def EUnpackedFloat.pack (uf : EUnpackedFloat (exponentWidth e s) (s + 1)) let shift := BitVec.ofInt _ (minNormalExp e) - uf.exp -- shift, and then truncate to significand width. 
(uf.sig >>> shift).truncate s - } attribute [bv_normalize] BitVec.zero diff --git a/Main.lean b/Main.lean index e3322bf..3a5449a 100644 --- a/Main.lean +++ b/Main.lean @@ -15,6 +15,10 @@ theorem h (f : FP8Format) : BitVec (1 + f.e + f.m) = BitVec 8 := by simp only [f.h8] end FP8Format +def PackedFloat.toBits' (pf : PackedFloat e s) (normNaN : Bool := true) := + let pf := if pf.isNaN && normNaN then .mkNaN else pf + pf.toBits + def toDigits (b : BitVec n) : String := let b' := b.reverse String.join ((List.finRange n).map (fun i => b'[i].toNat.digitChar.toString)) @@ -33,7 +37,7 @@ def test_add (f : FP8Format) (m : RoundingMode) (a b : BitVec 8) : OpResult := { oper := "add" mode := m - result := [a, b, f.h.mp (PackedFloat.add m a' b').toBits].map toDigits + result := [a, b, f.h.mp (PackedFloat.add m a' b').toBits'].map toDigits } def test_sub (f : FP8Format) (m : RoundingMode) (a b : BitVec 8) : OpResult := @@ -42,7 +46,7 @@ def test_sub (f : FP8Format) (m : RoundingMode) (a b : BitVec 8) : OpResult := { oper := "sub" mode := m - result := [a, b, f.h.mp (PackedFloat.sub m a' b').toBits].map toDigits + result := [a, b, f.h.mp (PackedFloat.sub m a' b').toBits'].map toDigits } def test_div (f : FP8Format) (m : RoundingMode) (a b : BitVec 8) : OpResult := @@ -51,7 +55,7 @@ def test_div (f : FP8Format) (m : RoundingMode) (a b : BitVec 8) : OpResult := { oper := "div" mode := m - result := [a, b, f.h.mp (PackedFloat.div m a' b' ).toBits].map toDigits + result := [a, b, f.h.mp (PackedFloat.div m a' b' ).toBits'].map toDigits } def test_mul (f : FP8Format) (m : RoundingMode) (a b : BitVec 8) : OpResult := @@ -60,7 +64,7 @@ def test_mul (f : FP8Format) (m : RoundingMode) (a b : BitVec 8) : OpResult := { oper := "mul" mode := m - result := [a, b, f.h.mp (PackedFloat.mul m a' b').toBits].map toDigits + result := [a, b, f.h.mp (PackedFloat.mul m a' b').toBits'].map toDigits } def test_lt (f : FP8Format) (m : RoundingMode) (a b : BitVec 8) : OpResult := @@ -78,7 +82,7 @@ def 
test_min (f : FP8Format) (m : RoundingMode) (a b : BitVec 8) : OpResult := { oper := "min" mode := m - result := [a, b, f.h.mp (flt_min a' b').toBits].map toDigits + result := [a, b, f.h.mp (flt_min a' b').toBits'].map toDigits } def test_max (f : FP8Format) (m : RoundingMode) (a b : BitVec 8) : OpResult := @@ -87,7 +91,7 @@ def test_max (f : FP8Format) (m : RoundingMode) (a b : BitVec 8) : OpResult := { oper := "max" mode := m - result := [a, b, f.h.mp (flt_max a' b').toBits].map toDigits + result := [a, b, f.h.mp (flt_max a' b').toBits'].map toDigits } def test_neg (f : FP8Format) (m : RoundingMode) (a : BitVec 8) : OpResult := @@ -95,7 +99,7 @@ def test_neg (f : FP8Format) (m : RoundingMode) (a : BitVec 8) : OpResult := { oper := "neg" mode := m - result := [a, 0#8, f.h.mp (PackedFloat.toBits a'.neg)].map toDigits + result := [a, 0#8, f.h.mp (PackedFloat.toBits' a'.neg)].map toDigits } def test_abs (f : FP8Format) (m : RoundingMode) (a : BitVec 8) : OpResult := @@ -103,7 +107,7 @@ def test_abs (f : FP8Format) (m : RoundingMode) (a : BitVec 8) : OpResult := { oper := "abs" mode := m - result := [a, 0#8, f.h.mp (PackedFloat.toBits a'.abs)].map toDigits + result := [a, 0#8, f.h.mp (PackedFloat.toBits' a'.abs)].map toDigits } def test_roundToInt (f : FP8Format) (m : RoundingMode) (a : BitVec 8) : OpResult := @@ -111,7 +115,7 @@ def test_roundToInt (f : FP8Format) (m : RoundingMode) (a : BitVec 8) : OpResult { oper := "roundToInt" mode := m - result := [a, 0#8, f.h.mp (PackedFloat.toBits (roundToInt m a'))].map toDigits + result := [a, 0#8, f.h.mp (PackedFloat.toBits' (roundToInt m a'))].map toDigits } def test_sqrt (f : FP8Format) (m : RoundingMode) (a : BitVec 8) : OpResult := @@ -119,7 +123,7 @@ def test_sqrt (f : FP8Format) (m : RoundingMode) (a : BitVec 8) : OpResult := { oper := "sqrt" mode := m - result := [a, 0#8, f.h.mp (PackedFloat.toBits (sqrt a' m))].map toDigits + result := [a, 0#8, f.h.mp (PackedFloat.toBits' (sqrt a' m))].map toDigits } def test_rem (f 
: FP8Format) (m : RoundingMode) (a b : BitVec 8) : OpResult := @@ -128,7 +132,7 @@ def test_rem (f : FP8Format) (m : RoundingMode) (a b : BitVec 8) : OpResult := { oper := "rem" mode := m - result := [a, b, f.h.mp (remainder a' b').toBits].map toDigits + result := [a, b, f.h.mp (remainder a' b').toBits'].map toDigits } def test_binop (f : RoundingMode → BitVec 8 → BitVec 8 → OpResult) : Thunk (List OpResult) := @@ -177,7 +181,7 @@ def test_fma (f : FP8Format) (m : RoundingMode) (a b c : BitVec 8) : OpResult := { oper := "fma" mode := m - result := [a, b, c, f.h.mp (fma a' b' c' m).toBits].map toDigits + result := [a, b, c, f.h.mp (fma a' b' c' m).toBits'].map toDigits } def test_ternop (f : RoundingMode → BitVec 8 → BitVec 8 → BitVec 8 → OpResult) (_ : Unit) : Thunk (List OpResult) :=