Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
183 changes: 101 additions & 82 deletions Fp/Addition.lean
Original file line number Diff line number Diff line change
@@ -1,92 +1,111 @@
import Fp.Basic
import Fp.Rounding
import Fp.UnpackedRound

/--
Addition of two fixed-point numbers.

When the sum is zero, the sign of the zero is dependent on the provided
rounding mode.
-/
@[bv_normalize]
def f_add (mode : RoundingMode) (a b : FixedPoint w e) : FixedPoint (w+1) e :=
let hExOffset : e < w+1 := by
exact Nat.lt_add_right 1 a.hExOffset
let ax := BitVec.setWidth' (by omega) a.val
let bx := BitVec.setWidth' (by omega) b.val
if a.sign == b.sign then
-- Addition of same-signed numbers always preserves sign
{
sign := a.sign
val := BitVec.add ax bx
hExOffset := hExOffset
}
else if BitVec.ult ax bx then
{
sign := b.sign
val := BitVec.sub bx ax
hExOffset := hExOffset
}
else if BitVec.ult bx ax then
def UnpackedFloat.add (sign : Bool) (x y : UnpackedFloat e s) : UnpackedFloat (e + 1) (s + 2) :=
-- Compute absolute exponent difference to determine significant shift amount.
let expDiff : BitVec (e + 1) := x.ex.signExtend (e + 1) - y.ex.signExtend (e + 1)
let absExpDiff := bif expDiff.msb then -expDiff else expDiff
-- Determine the smaller number whose significant we are going to shift.
let (x, y) := bif expDiff.msb || absExpDiff == 0 && x.sig.ult y.sig then (y, x) else (x, y)
-- Extend by 1 bit to the left to account for overflow and two bits to the right to account for
-- round and sticky bits.
let xSig := x.sig.setWidth' (by omega) ++ 0#2
let ySig := y.sig.setWidth' (by omega) ++ 0#2
-- Reuse the same circuit for both addition and subtraction by negating the smaller significant.
-- Note: we always treat significants as unsigned integers. However, we make an exception here
-- to reuse the same adder circuit for both addition and subtraction.
let ySig := bif x.sign == y.sign then ySig else -ySig
-- Cap the right shift at `s + 3` in case `absExpDiff` is too big.
let shiftAmount := bif absExpDiff.ult (s + 3) then absExpDiff.setWidth (s + 3) else (s + 3)
-- Note: we use signed shift for `ySig` here to preserve its sign since it's now a signed integer.
let sigSum : BitVec (s + 3) := xSig + ySig.sshiftRight' shiftAmount
-- Sticky bit depends on bits we lose when we right shift `ySig` and `sigSum` (in case of an overflow).
let sticky := ySig &&& shiftAmount.orderEncode != 0 || sigSum.msb && sigSum.getLsb 0
let sum : UnpackedFloat (e + 1) (s + 2) :=
{
sign := a.sign
val := BitVec.sub ax bx
hExOffset := hExOffset
-- Sign of sum is sign of the bigger number!
sign := x.sign
-- Exponent of sum is exponent of bigger number (`+1` if there is an overflow).
ex := x.ex.signExtend (e + 1) + (BitVec.ofBool sigSum.msb).setWidth' (by omega)
-- Renormalize `sigSum` if there is an overflow.
sig := (sigSum >>> BitVec.ofBool sigSum.msb).truncate (s + 2)
}
-- If a catastrophic cancellation occured, we have to normalize. In case the sum is `0` (i.e., full
-- cancellation), the sign depends on the rounding mode.
let normSum := bif !sum.sig.msb then sum.normalize sign else sum
-- Sticky bit is independent of normalization: add it at the very end.
{ normSum with sig := normSum.sig ||| (BitVec.ofBool sticky).setWidth' (by omega) }

def EUnpackedFloat.add (m : RoundingMode) (x y : EUnpackedFloat (exponentWidth e s) (s + 1))
: EUnpackedFloat (exponentWidth e s) (s + 1) :=
bif x.isZero && !y.isZero then
y
else bif !x.isZero && y.isZero then
x
else bif x.isNaN || y.isNaN || x.isInfinite && y.isInfinite && x.sign != y.sign then
.mkNaN
else bif x.isInfinite && y.isInfinite && x.sign == y.sign ||
x.isInfinite && !y.isInfinite || !x.isInfinite && y.isInfinite then
.mkInfinity (bif x.isInfinite then x.sign else y.sign)
else bif x.isZero && y.isZero then
.mkZero (bif m == .RTN then x.sign || y.sign else x.sign && y.sign)
else
-- Signs are different but values are same, so return +0.0
-- When rounding mode is RTN we should instead return -0.0
{
sign := mode = .RTN
val := 0#_
hExOffset := hExOffset
}
UnpackedFloat.round (.add (m == .RTN) x.num y.num) m

/--
Addition of two extended fixed-point numbers.
namespace PackedFloat

When the sum is zero, the sign of the zero is dependent on the provided
rounding mode.
-/
@[bv_normalize]
def e_add (mode : RoundingMode) (a b : EFixedPoint w e) : EFixedPoint (w+1) e :=
open EFixedPoint in
let hExOffset : e < w + 1 := by
exact Nat.lt_add_right 1 a.num.hExOffset
-- As of 2025-04-14, bv_decide does not support pattern matches on more than
-- one variable, so we'll have to deal with if-statements for now
if hN : a.state = .NaN || b.state = .NaN then getNaN hExOffset
else if hI1 : a.state = .Infinity && b.state = .Infinity then
if a.num.sign == b.num.sign then getInfinity a.num.sign hExOffset
else getNaN hExOffset
else if hI2 : a.state = .Infinity then getInfinity a.num.sign hExOffset
else if hI3 : b.state = .Infinity then getInfinity b.num.sign hExOffset
else
-- is this how to do assertions?
let _ : a.state = .Number && b.state = .Number := by
cases ha : a.state <;> cases hb : b.state <;> simp_all
{
state := .Number
num := f_add mode a.num b.num
}
def add (m : RoundingMode) (x y : PackedFloat e s) : PackedFloat e s :=
(EUnpackedFloat.add m x.unpack y.unpack).pack

instance : Add (PackedFloat e s) where
add := .add .RNE

end PackedFloat

-- Minor cancellation with rounding

/-- info: ExtRat.Number (5 : Rat)/64 -/
#guard_msgs in #eval (PackedFloat.ofBits 3 4 0b00000101).toExtRat
/-- info: ExtRat.Number -2 -/
#guard_msgs in #eval (PackedFloat.ofBits 3 4 0b11000000).toExtRat
/-- info: -123 / 64 -/
#guard_msgs in #eval (5 : Rat)/64 + -2
/-- info: ExtRat.Number (-31 : Rat)/16 -/
#guard_msgs in #eval (PackedFloat.ofRat 3 4 .RNE (-123) 64).toExtRat
/-- info: ExtRat.Number (-31 : Rat)/16 -/
#guard_msgs in #eval (PackedFloat.add .RNE (PackedFloat.ofBits 3 4 0b00000101) (PackedFloat.ofBits 3 4 0b11000000)).toExtRat

-- Minor cancellation without rounding

/-- info: ExtRat.Number (5 : Rat)/64 -/
#guard_msgs in #eval (PackedFloat.ofBits 3 4 0b00000101).toExtRat
/-- info: ExtRat.Number -4 -/
#guard_msgs in #eval (PackedFloat.ofBits 3 4 0b11010000).toExtRat
/-- info: -251 / 64 -/
#guard_msgs in #eval (5 : Rat)/64 + -4
/-- info: ExtRat.Number (-31 : Rat)/8 -/
#guard_msgs in #eval (PackedFloat.ofRat 3 4 .RNE (-251) 64).toExtRat
/-- info: ExtRat.Number (-31 : Rat)/8 -/
#guard_msgs in #eval (PackedFloat.add .RNE (PackedFloat.ofBits 3 4 0b00000101) (PackedFloat.ofBits 3 4 0b11010000)).toExtRat

/--
Addition of two floating point numbers, rounded to a floating point number
using the provided rounding mode.
-/
@[bv_normalize]
def add (a b : PackedFloat e s) (mode : RoundingMode) : PackedFloat e s :=
EFixedPoint.round _ _ mode (e_add mode a.toEFixed b.toEFixed)
-- Rounding Modes

-- Proof by brute force (it takes a while)
/-
theorem PackedFloat_add_comm' (m : RoundingMode) (a b : PackedFloat 5 2)
: (add a b m) = (add b a m) := by
apply PackedFloat.inj
simp [add, e_add, f_add, round, PackedFloat.toEFixed, -BitVec.shiftLeft_eq', -BitVec.ushiftRight_eq']
bv_decide
-/
/-- info: ExtRat.Number (-1 : Rat)/16384 -/
#guard_msgs in #eval (PackedFloat.ofBits 5 2 0b10000100#8).toExtRat
/-- info: ExtRat.Number (5 : Rat)/8192 -/
#guard_msgs in #eval (PackedFloat.ofBits 5 2 0b00010001#8).toExtRat
/-- info: 9 / 16384 -/
#guard_msgs in #eval (-1 : Rat)/16384 + (5 : Rat)/8192
/-- info: ExtRat.Number (1 : Rat)/2048 -/
#guard_msgs in #eval (PackedFloat.ofRat 5 2 .RNE 9 16384).toExtRat
/-- info: ExtRat.Number (5 : Rat)/8192 -/
#guard_msgs in #eval (PackedFloat.ofRat 5 2 .RNA 9 16384).toExtRat
/-- info: ExtRat.Number (1 : Rat)/2048 -/
#guard_msgs in #eval (PackedFloat.add .RNE (PackedFloat.ofBits 5 2 0b10000100#8) (PackedFloat.ofBits 5 2 0b00010001#8)).toExtRat
/-- info: ExtRat.Number (5 : Rat)/8192 -/
#guard_msgs in #eval (PackedFloat.add .RNA (PackedFloat.ofBits 5 2 0b10000100#8) (PackedFloat.ofBits 5 2 0b00010001#8)).toExtRat

/-- info: { sign := +, ex := 0x04#5, sig := 0x0#2 } -/
#guard_msgs in #eval add (PackedFloat.ofBits 5 2 0b10000100#8) (PackedFloat.ofBits 5 2 0b00010001#8) .RNE
/-- info: { sign := +, ex := 0x04#5, sig := 0x1#2 } -/
#guard_msgs in #eval add (PackedFloat.ofBits 5 2 0b10000100#8) (PackedFloat.ofBits 5 2 0b00010001#8) .RNA
/-- info: ExtRat.Infinity true -/
#guard_msgs in #eval (PackedFloat.add .RNE (PackedFloat.ofBits 5 2 0b11111100#8) (PackedFloat.ofBits 5 2 0b00010001#8)).toExtRat
/-- info: ExtRat.Infinity false -/
#guard_msgs in #eval (PackedFloat.add .RNE (PackedFloat.ofBits 5 2 0b01111100#8) (PackedFloat.ofBits 5 2 0b00010001#8)).toExtRat
9 changes: 6 additions & 3 deletions Fp/Basic.lean
Original file line number Diff line number Diff line change
Expand Up @@ -879,6 +879,9 @@ def bias (e : Nat) : Nat :=

namespace PackedFloat

def mkNaN (sign := false) (sig := 1#s <<< (s - 1)) : PackedFloat e s :=
{ sign, ex := BitVec.allOnes e, sig }

def toExtDyadic (pf : PackedFloat e s) : ExtDyadic :=
bif pf.isNaN then
.NaN
Expand Down Expand Up @@ -962,10 +965,10 @@ def isZero (uf : UnpackedFloat e s) : Bool :=
uf.ex == 0 && uf.sig == 0

@[bv_normalize]
def normalize (uf : UnpackedFloat e s) : UnpackedFloat e s :=
bif uf.sig.clz == s then
def normalize (uf : UnpackedFloat e s) (sign := uf.sign) : UnpackedFloat e s :=
bif uf.sig == 0#s then
-- zero case: make it explicit!
mkZero uf.sign
mkZero sign
else
{
sign := uf.sign
Expand Down
68 changes: 68 additions & 0 deletions Fp/FMA.lean
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,74 @@ import Fp.Rounding
import Fp.Addition
import Fp.Multiplication

/--
Addition of two fixed-point numbers.

When the sum is zero, the sign of the zero is dependent on the provided
rounding mode.
-/
@[bv_normalize]
def f_add (mode : RoundingMode) (a b : FixedPoint w e) : FixedPoint (w+1) e :=
let hExOffset : e < w+1 := by
exact Nat.lt_add_right 1 a.hExOffset
let ax := BitVec.setWidth' (by omega) a.val
let bx := BitVec.setWidth' (by omega) b.val
if a.sign == b.sign then
-- Addition of same-signed numbers always preserves sign
{
sign := a.sign
val := BitVec.add ax bx
hExOffset := hExOffset
}
else if BitVec.ult ax bx then
{
sign := b.sign
val := BitVec.sub bx ax
hExOffset := hExOffset
}
else if BitVec.ult bx ax then
{
sign := a.sign
val := BitVec.sub ax bx
hExOffset := hExOffset
}
else
-- Signs are different but values are same, so return +0.0
-- When rounding mode is RTN we should instead return -0.0
{
sign := mode = .RTN
val := 0#_
hExOffset := hExOffset
}

/--
Addition of two extended fixed-point numbers.

When the sum is zero, the sign of the zero is dependent on the provided
rounding mode.
-/
@[bv_normalize]
def e_add (mode : RoundingMode) (a b : EFixedPoint w e) : EFixedPoint (w+1) e :=
open EFixedPoint in
let hExOffset : e < w + 1 := by
exact Nat.lt_add_right 1 a.num.hExOffset
-- As of 2025-04-14, bv_decide does not support pattern matches on more than
-- one variable, so we'll have to deal with if-statements for now
if hN : a.state = .NaN || b.state = .NaN then getNaN hExOffset
else if hI1 : a.state = .Infinity && b.state = .Infinity then
if a.num.sign == b.num.sign then getInfinity a.num.sign hExOffset
else getNaN hExOffset
else if hI2 : a.state = .Infinity then getInfinity a.num.sign hExOffset
else if hI3 : b.state = .Infinity then getInfinity b.num.sign hExOffset
else
-- is this how to do assertions?
let _ : a.state = .Number && b.state = .Number := by
cases ha : a.state <;> cases hb : b.state <;> simp_all
{
state := .Number
num := f_add mode a.num b.num
}

@[bv_normalize]
def fma (a b c : PackedFloat e s) (m : RoundingMode)
: PackedFloat e s :=
Expand Down
99 changes: 60 additions & 39 deletions Fp/Negation.lean
Original file line number Diff line number Diff line change
@@ -1,40 +1,61 @@
import Fp.Basic
import Fp.Rounding

/-- Negate the fixed point number -/
@[bv_normalize]
def f_neg (a : FixedPoint w e) : FixedPoint w e := { a with sign := !a.sign }

/-- Negate the extended fixed point number -/
@[bv_normalize]
def e_neg (a : EFixedPoint w e) : EFixedPoint w e :=
open EFixedPoint in
have := a.num.hExOffset
if hN : a.state = .NaN then
getNaN (by omega)
else if hInf : a.state = .Infinity then
getInfinity (!a.num.sign) (by omega)
else
let _ : a.state = .Number := by
cases h : a.state <;> simp_all
{ a with num := f_neg a.num }

/-- Negate a floating-point number, by conversion to a fixed-point number. -/
@[bv_normalize]
def negfixed (a : PackedFloat e s) (mode : RoundingMode) : PackedFloat e s :=
EFixedPoint.round _ _ mode (e_neg a.toEFixed)

/--
Negate a floating-point number, by flipping the sign bit.

This implements the same function as `negfixed`, but is much simpler.
-/
@[bv_normalize]
def neg (a : PackedFloat e s) : PackedFloat e s :=
if a.isNaN then PackedFloat.getNaN _ _
else { a with sign := !a.sign }

@[bv_normalize]
def abs (a : PackedFloat e s) : PackedFloat e s :=
if a.isNaN then PackedFloat.getNaN _ _
else { a with sign := false }
import Fp.Packing

def UnpackedFloat.neg (x : UnpackedFloat e s) : UnpackedFloat e s :=
{ x with sign := !x.sign }

def UnpackedFloat.abs (x : UnpackedFloat e s) : UnpackedFloat e s :=
{ x with sign := false }

def EUnpackedFloat.neg (x : EUnpackedFloat (exponentWidth e s) (s + 1))
: EUnpackedFloat (exponentWidth e s) (s + 1) :=
{ x with num := x.num.neg }

def EUnpackedFloat.abs (x : EUnpackedFloat (exponentWidth e s) (s + 1))
: EUnpackedFloat (exponentWidth e s) (s + 1) :=
{ x with num := x.num.abs }

namespace PackedFloat

def neg (x : PackedFloat e s) : PackedFloat e s :=
-- At first glance, unpacking a float just to modify its sign might look like
-- unnecessary work; after all, you *can* toggle the sign bit directly on the
-- packed representation. And if this operation lived in isolation, that would
-- indeed be the simpler route.
--
-- However, most floating‑point operations already require unpacking, and once
-- you're working in that representation, it's usually best to stay there.
-- Thanks to the identity
--
-- `∀ uf m, (uf.round m).pack.unpack = uf.round m`
--
-- we can keep everything in the unpacked form throughout a chain of
-- operations and only repack at the very end. Unpacking here helps maintain
-- that workflow and avoids bouncing back and forth between representations.
--
-- TODO: is there a way to hint this optimization strategy to the compiler?
x.unpack.neg.pack

instance : Neg (PackedFloat e s) where
neg := .neg

def abs (x : PackedFloat e s) : PackedFloat e s :=
-- At first glance, unpacking a float just to modify its sign might look like
-- unnecessary work; after all, you *can* toggle the sign bit directly on the
-- packed representation. And if this operation lived in isolation, that would
-- indeed be the simpler route.
--
-- However, most floating‑point operations already require unpacking, and once
-- you're working in that representation, it's usually best to stay there.
-- Thanks to the identity
--
-- `∀ uf m, (uf.round m).pack.unpack = uf.round m`
--
-- we can keep everything in the unpacked form throughout a chain of
-- operations and only repack at the very end. Unpacking here helps maintain
-- that workflow and avoids bouncing back and forth between representations.
--
-- TODO: is there a way to hint this optimization strategy to the compiler?
x.unpack.abs.pack

end PackedFloat
Loading