diff --git a/Strata/DDM/AST.lean b/Strata/DDM/AST.lean index c8c46aeb8..b40dab2c9 100644 --- a/Strata/DDM/AST.lean +++ b/Strata/DDM/AST.lean @@ -10,7 +10,8 @@ public import Strata.DDM.Util.ByteArray public import Strata.DDM.Util.Decimal import Std.Data.HashMap -import Strata.DDM.Util.Array +import all Strata.DDM.Util.Array +import all Strata.DDM.Util.ByteArray set_option autoImplicit false @@ -29,7 +30,7 @@ namespace QualifiedIdent def fullName (i : QualifiedIdent) : String := s!"{i.dialect}.{i.name}" instance : ToString QualifiedIdent where - toString := private fullName + toString := fullName section open _root_.Lean @@ -554,12 +555,14 @@ structure DebruijnIndex (n : Nat) where isLt : val < n deriving Repr - namespace DebruijnIndex def toLevel {n} : DebruijnIndex n → Fin n | ⟨v, lt⟩ => ⟨n - (v+1), by omega⟩ +protected def ofNat {n : Nat} [NeZero n] (a : Nat) : DebruijnIndex n := + ⟨a % n, Nat.mod_lt _ (Nat.pos_of_neZero n)⟩ + end DebruijnIndex /-- diff --git a/Strata/DDM/Elab.lean b/Strata/DDM/Elab.lean index 3be340762..2dbbd7208 100644 --- a/Strata/DDM/Elab.lean +++ b/Strata/DDM/Elab.lean @@ -12,6 +12,8 @@ import Strata.DDM.BuiltinDialects import Strata.DDM.Util.Ion.Serialize import Strata.Util.IO +import all Strata.DDM.Util.ByteArray + open Lean (Message) open Strata.Parser (InputContext) diff --git a/Strata/DDM/Elab/Core.lean b/Strata/DDM/Elab/Core.lean index 0eace2672..fe355e11d 100644 --- a/Strata/DDM/Elab/Core.lean +++ b/Strata/DDM/Elab/Core.lean @@ -8,8 +8,8 @@ module public import Strata.DDM.Elab.DeclM public import Strata.DDM.Elab.Tree -import Strata.DDM.Util.Array -import Strata.DDM.Util.Fin +import all Strata.DDM.Util.Array +import all Strata.DDM.Util.Fin import Strata.DDM.HNF open Lean ( diff --git a/Strata/DDM/Elab/DialectM.lean b/Strata/DDM/Elab/DialectM.lean index a2e9baa32..b0c2ff3d0 100644 --- a/Strata/DDM/Elab/DialectM.lean +++ b/Strata/DDM/Elab/DialectM.lean @@ -9,8 +9,8 @@ public import Strata.DDM.AST public import Strata.DDM.Elab.Core import Std.Data.HashMap -import Strata.DDM.Util.Array -import Strata.DDM.Util.Fin +import all Strata.DDM.Util.Array +import all Strata.DDM.Util.Fin set_option autoImplicit false diff --git a/Strata/DDM/Format.lean b/Strata/DDM/Format.lean index b44486407..b7ebb5adb 100644 --- a/Strata/DDM/Format.lean +++ b/Strata/DDM/Format.lean @@ -5,14 +5,11 @@ -/ module +import Std.Data.HashSet public import Strata.DDM.AST - import Strata.DDM.Util.Format import Strata.DDM.Util.Nat -import Strata.DDM.Util.String -import Std.Data.HashSet - -meta import Strata.DDM.AST +import all Strata.DDM.Util.String open Std (Format format) @@ -114,7 +111,7 @@ private def fvarName (ctx : FormatContext) (idx : FreeVarIndex) : String := else s!"fvar!{idx}" -protected def ofDialects (dialects : DialectMap) (globalContext : GlobalContext) (opts : FormatOptions) : FormatContext where +protected def ofDialects (dialects : DialectMap) (globalContext : GlobalContext := {}) (opts : FormatOptions := {}) : FormatContext where opts := opts getFnDecl sym := Id.run do let .function f := dialects.decl! sym @@ -444,13 +441,13 @@ private partial def OperationF.mformatM (op : OperationF α) : FormatM PrecForma end -instance Expr.instToStrataFormat : ToStrataFormat Expr where +instance Expr.instToStrataFormat {α} : ToStrataFormat (ExprF α) where mformat e c s := private e.mformatM #[] c s |>.fst -instance Arg.instToStrataFormat : ToStrataFormat Arg where +instance Arg.instToStrataFormat {α} : ToStrataFormat (ArgF α) where mformat a c s := private a.mformatM c s |>.fst -instance Operation.instToStrataFormat : ToStrataFormat Operation where +instance Operation.instToStrataFormat {α} : ToStrataFormat (OperationF α) where mformat o c s := private o.mformatM c s |>.fst namespace MetadataArg diff --git a/Strata/DDM/HNF.lean b/Strata/DDM/HNF.lean index df24c9be4..811add501 100644 --- a/Strata/DDM/HNF.lean +++ b/Strata/DDM/HNF.lean @@ -6,7 +6,7 @@ module public import Strata.DDM.AST -import Strata.DDM.Util.Array +import all Strata.DDM.Util.Array public section namespace Strata diff --git a/Strata/DDM/Integration/Lean.lean b/Strata/DDM/Integration/Lean.lean index 8b400b78d..eccc1435f 100644 --- a/Strata/DDM/Integration/Lean.lean +++ b/Strata/DDM/Integration/Lean.lean @@ -3,6 +3,7 @@ SPDX-License-Identifier: Apache-2.0 OR MIT -/ +module -import Strata.DDM.Integration.Lean.Gen -import Strata.DDM.Integration.Lean.HashCommands +public import Strata.DDM.Integration.Lean.Gen +public import Strata.DDM.Integration.Lean.HashCommands diff --git a/Strata/DDM/Integration/Lean/BoolConv.lean b/Strata/DDM/Integration/Lean/BoolConv.lean index ce83f5902..0c371790c 100644 --- a/Strata/DDM/Integration/Lean/BoolConv.lean +++ b/Strata/DDM/Integration/Lean/BoolConv.lean @@ -11,27 +11,28 @@ public section namespace Strata /-- Convert Init.Bool inductive to OperationF -/ -def Bool.toAst {α} [Inhabited α] (v : Ann Bool α) : OperationF α := - if v.val then - ⟨v.ann, q`Init.boolTrue, #[]⟩ +def OperationF.ofBool {α} (ann : α) (b : Bool) : OperationF α := + if b then + { ann := ann, name := q`Init.boolTrue, args := #[] } else - ⟨v.ann, q`Init.boolFalse, #[]⟩ + { ann := ann, name := q`Init.boolFalse, args := #[] } /-- Convert OperationF to Init.Bool -/ -def Bool.ofAst {α} [Inhabited α] [Repr α] (op : OperationF α) : OfAstM (Ann Bool α) := - match op.name with - | q`Init.boolTrue => - if op.args.size = 0 then - pure ⟨op.ann, true⟩ - else - .error s!"boolTrue expects 0 arguments, got {op.args.size}" - | q`Init.boolFalse => - if op.args.size = 0 then - pure ⟨op.ann, false⟩ - else - .error s!"boolFalse expects 0 arguments, got {op.args.size}" - | _ => - .error s!"Unknown Bool operator: {op.name}" +def Bool.ofAst {α} [Inhabited α] [Repr α] (arg : ArgF α) : OfAstM Bool := do + match arg with + | .op op => + match op.name with + | q`Init.boolTrue => + if op.args.size ≠ 0 then + .error s!"boolTrue expects 0 arguments, got {op.args.size}" + pure true + | q`Init.boolFalse => + if op.args.size ≠ 0 then + .error s!"boolFalse expects 0 arguments, got {op.args.size}" + pure false + | _ => + .error s!"Unknown Bool operator: {op.name}" + | _ => .throwExpected "boolean" arg end Strata end diff --git a/Strata/DDM/Integration/Lean/Gen.lean b/Strata/DDM/Integration/Lean/Gen.lean index 48829ed24..f66ad75f6 100644 --- a/Strata/DDM/Integration/Lean/Gen.lean +++ b/Strata/DDM/Integration/Lean/Gen.lean @@ -3,20 +3,29 @@ SPDX-License-Identifier: Apache-2.0 OR MIT -/ -import Lean.Elab.Command - -import Strata.DDM.BuiltinDialects.Init -import Strata.DDM.BuiltinDialects.StrataDDL -import Strata.DDM.Integration.Categories -import Strata.DDM.Integration.Lean.BoolConv -import Strata.DDM.Integration.Lean.Env -import Strata.DDM.Integration.Lean.GenTrace -import Strata.DDM.Integration.Lean.OfAstM -import Strata.DDM.Util.Graph.Tarjan - -open Lean (Command Name Ident Term TSyntax getEnv logError profileitM quote withTraceNode mkIdentFrom) +module +public meta import Lean.Elab.Command + +public meta import Strata.DDM.AST +public meta import Strata.DDM.BuiltinDialects.Init +public import Strata.DDM.Integration.Lean.BoolConv +public import Strata.DDM.Integration.Lean.GenTrace +public meta import Strata.DDM.Integration.Lean.Env +public meta import Strata.DDM.Util.Graph.Preimage +public meta import Strata.DDM.Util.Graph.Tarjan +public meta import Strata.DDM.Util.OrderedSet + +meta import Strata.DDM.BuiltinDialects.StrataDDL +import all Strata.DDM.Util.Array +import all Strata.DDM.Util.Vector + +import Std.Data.HashSet.Lemmas + +open Lean (Command Name Ident Term TSyntax addAndCompile getEnv logError) +open Lean (mkApp2 mkApp3 mkAppN mkCIdent mkConst mkIdentFrom) +open Lean (profileitM quote withTraceNode) open Lean.Elab (throwUnsupportedSyntax) -open Lean.Elab.Command (CommandElab CommandElabM elabCommand) +open Lean.Elab.Command (CommandElab CommandElabM elabCommand liftCoreM) open Lean.MonadOptions (getOptions) open Lean.MonadResolveName (getCurrNamespace) open Lean.Parser.Command (ctor) @@ -24,136 +33,325 @@ open Lean.Parser.Term (bracketedBinderF doSeqItem matchAltExpr) open Lean.Parser.Termination (terminationBy suffix) open Lean.Syntax (mkApp mkCApp mkStrLit) -namespace Strata +open Strata.DDM (OrderedSet) -namespace Lean +meta section + +namespace Strata.OutGraph /-- -Prepend the current namespace to the Lean name and convert to an identifier. +An index map `m` contains a partial mapping information of +nodes from a graph with `n` nodes to a graph `m._size` +nodes. Nodes in the source graph should either map to a unique +node in the target graph or `0` if they are not included. -/ -def mkScopedIdent (scope : Name) (subName : Lean.Name) : Ident := - let fullName := scope ++ subName - let nameStr := toString subName - .mk (.ident .none nameStr.toSubstring subName [.decl fullName []]) +structure PartialNodeMapping (n : Nat) where + /-- + Map from indices in source graph to target graph to either + `0` if the node is droped or `idx + 1` if the node is in the + target graph. + -/ + embed : Vector Nat n + /-- + Reverse map from target graph to source graph. + + N.B. The number of elements is the size of the target graph. + -/ + rev : Array Nat + /-- + Constraint that all reverse node maps are in bound. + -/ + revBound : ∀(m : Nat), m ∈ rev → m < n + /-- + Constraint that all + -/ + embedBound : ∀(m : Nat), m ∈ embed → m ≤ rev.size + +namespace PartialNodeMapping + +abbrev targetSize {n} (im : PartialNodeMapping n) : Nat := im.rev.size + +def empty (source_capacity : Nat) : PartialNodeMapping 0 := { + embed := Vector.emptyWithCapacity source_capacity + rev := #[] + revBound := by grind + embedBound := fun m mem => by simp [Vector.emptyWithCapacity] at mem; +} /-- -Prepend the current namespace to the Lean name and convert to an identifier. +Return index of source node in targt graph. -/ -def currScopedIdent {m} [Monad m] [Lean.MonadResolveName m] (subName : Lean.Name) : m Ident := do - (mkScopedIdent · subName) <$> getCurrNamespace +def sourceNode {n} (m : PartialNodeMapping n) (i : Nat) + (p : i < m.targetSize := by get_elem_tactic) : Fin n := + ⟨m.rev[i], m.revBound _ (by grind)⟩ + +def cast {m n} (em : PartialNodeMapping m) (eq : m = n) : PartialNodeMapping n := + have q : PartialNodeMapping m = PartialNodeMapping n := congrArg PartialNodeMapping eq + q.mp em + +def embedFn {n} (em : PartialNodeMapping n) (i : Fin n) : Fin (em.targetSize + 1) := + let e := em.embed[i] + have mem : e ∈ em.embed := by simp [e] + have _ : e ≤ em.targetSize := em.embedBound _ mem + have ep : e < em.targetSize + 1 := by omega + ⟨e, ep⟩ + +def extend {n} (em : PartialNodeMapping n) (mem : Bool) : PartialNodeMapping (n+1) := + let {rev, embed, revBound, embedBound } := em + if mem then + let rev' := rev.push embed.size + { rev := rev' + embed := embed.push rev'.size + revBound := by grind + embedBound := by grind + } + else + { rev := rev + embed := embed.push 0 + revBound := fun v mem => by grind + embedBound := fun m mem => by grind + } + +def fromVector.aux {m n} (usedSet : Vector Bool n) (t : PartialNodeMapping m) (ap : m ≤ n := by omega) : PartialNodeMapping n := + if ilt : m < n then + fromVector.aux usedSet (t.extend usedSet[m]) + else + t.cast (by omega) + termination_by n - m -end Lean +def fromVector {n} (usedSet : Vector Bool n) : PartialNodeMapping n := + fromVector.aux usedSet (.empty n) (by grind) -open Lean (currScopedIdent) +end PartialNodeMapping -private def arrayLit [Monad m] [Lean.MonadQuotation m] (as : Array Term) : m Term := do - ``( (#[ $as:term,* ] : Array _) ) +/-- +`g.projection emb` generates the minimal graph `g'` with `n` vertices which for each edge +each edges `⟨s, t⟩ ∈ g`, `g'` contains the edge `⟨emb s - 1, emb t -1⟩` if `emb s, emb t > 0`. -private def vecLit [Monad m] [Lean.MonadQuotation m] (as : Array Term) : m Term := do - ``( (#v[ $as:term,* ] : Vector _ $(quote as.size)) ) +This is a partial homomophism since nodes can be dropped by mapping then to `n`. +-/ +def projection {m n} (g : OutGraph m) (emb : Fin m -> Fin (n+1)) : OutGraph n := + let gn : OutGraph n := { edges := .replicate n #[] } + let addEdge (a : Array (Fin n)) (i : Fin m) := + let ⟨j, jlt⟩ := emb i + if jp : j = 0 then + a + else + a.push ⟨j - 1, by omega⟩ + let appendEdgesToArray (i : Fin m) (a : Array (Fin n)) : Array (Fin n) := + g.edges[i].foldl (init := a) addEdge + let appendEdges (g : OutGraph n) (i : Fin m) : OutGraph n := + let ⟨j, jlt⟩ := emb i + if jp : j = 0 then + gn + else + { edges := gn.edges.modify ⟨j-1, by omega⟩ (appendEdgesToArray i) } + Fin.foldl m (init := gn) appendEdges -abbrev LeanCategoryName := Lean.Name -structure GenContext where - -- Syntax for #strata_gen for source location purposes. - src : Lean.Syntax - categoryNameMap : Std.HashMap QualifiedIdent String - exprHasEta : Bool -abbrev GenM := ReaderT GenContext CommandElabM +end Strata.OutGraph -def runCmd {α} (act : CommandElabM α) : GenM α := fun _ => act +namespace Strata.Lean -/-- Create a fresh name. -/ -private def genFreshLeanName (s : String) : GenM Name := do - let fresh ← modifyGet fun s => (s.nextMacroScope, { s with nextMacroScope := s.nextMacroScope + 1 }) - let n : Name := .anonymous |>.str s - return Lean.addMacroScope (← getEnv).mainModule n fresh +/-- +Prepend the current namespace to the Lean name and convert to an identifier. +-/ +def scopedIdent (scope subName : Lean.Name) : Ident := + let name := scope ++ subName + let nameStr := toString subName + .mk (.ident .none nameStr.toSubstring subName [.decl name []]) -/-- Create a fresh name. -/ -private def genFreshIdentPair (s : String) : GenM (Ident × Ident) := do - let name ← genFreshLeanName s - let src := (←read).src - return (mkIdentFrom src name true, mkIdentFrom src name) +/-- +Prepend the current namespace to the Lean name and convert to an identifier. +-/ +def currScopedIdent {m} [Monad m] [Lean.MonadResolveName m] + (subName : Lean.Name) : m Ident := do + (scopedIdent · subName) <$> getCurrNamespace + +/- Returns an identifier from a string. -/ +def localIdent (name : String) : Ident := + let dName := .anonymous |>.str name + .mk (.ident .none name.toSubstring dName []) /-- Create a canonical identifier. -/ def mkCanIdent (src : Lean.Syntax) (val : Name) : Ident := mkIdentFrom src val true -/-- Create a identifier from a name. -/ -private def genIdentFrom (name : Name) (canonical : Bool := false) : GenM Ident := do - return mkIdentFrom (←read).src name canonical +/-- +Create an identifier to a fully qualified Lean name +-/ +def mkRootIdent (name : Name) : Ident := + let rootName := `_root_ ++ name + .mk (.ident .none name.toString.toSubstring rootName [.decl name []]) -def reservedCats : Std.HashSet String := { "Type" } +end Lean -structure OrderedSet (α : Type _) [BEq α] [Hashable α] where - set : Std.HashSet α - values : Array α +open Lean (currScopedIdent localIdent mkCanIdent mkRootIdent) -namespace OrderedSet +def arrayLit {m} [Monad m] [Lean.MonadQuotation m] (as : Array Term) : m Term := do + ``( (#[ $as:term,* ] : Array _) ) -def empty [BEq α] [Hashable α] : OrderedSet α := { set := {}, values := #[]} +namespace SyntaxCatF -partial def addAtom {α} [BEq α] [Hashable α] (s : OrderedSet α) (a : α) : OrderedSet α := - if a ∈ s.set then - s +/-- +Invoke `f` over all atomic (no argument) category names in `c`. +-/ +def foldOverAtomicCategories {α} + (cat : SyntaxCat) (init : α) (f : α → QualifiedIdent → α) : α := + if cat.args.size = 0 then + f init cat.name else - { set := s.set.insert a, values := s.values.push a } + cat.args.foldl (init := init) fun v a => foldOverAtomicCategories a v f +termination_by cat +decreasing_by + rw [sizeOf_spec cat] + decreasing_tactic -partial def addPostTC {α} [BEq α] [Hashable α] (next : α → Array α) (s : OrderedSet α) (a : α) : OrderedSet α := - if a ∈ s.set then - s +/-- +If we have a an action that is preserves an invariant on all names, then it +is preserved when folding over the atomic categories. +-/ +theorem foldOverAtomicCategories_invariant {α β} + (measure : α → β) + {f : α → QualifiedIdent → α} + (inv : ∀t name, measure (f t name) = measure t) + (cat : SyntaxCat) (init : α) : + measure (cat.foldOverAtomicCategories init f) = measure init := by + + unfold foldOverAtomicCategories + if p : cat.args.size = 0 then + simp [p, inv] else - let as := next a - let s := { s with set := s.set.insert a } - let s := as.foldl (init := s) (addPostTC next) - { s with values := s.values.push a } + simp [p] + apply Array.foldl_induction (motive := fun _ s => measure s = measure init) (h0 := rfl) + intro i b b_eq + simp only [foldOverAtomicCategories_invariant measure inv cat.args[i]] + exact b_eq +termination_by cat +decreasing_by + rw [sizeOf_spec cat] + decreasing_tactic -end OrderedSet +end SyntaxCatF -def generateDependentDialects (lookup : String → Option Dialect) (name : DialectName) : Array DialectName := +/-- +Given a dialect map and a specific dialect, this computes an array +with the imported dialects. The entries are ordered so that all a dialects +imports appear before it in the array. +-/ +def computeImportedDialects (dialects : DialectMap) + (name : DialectName) : Array DialectName := let s : OrderedSet DialectName := .empty - let s := s.addAtom initDialect.name - let next (name : DialectName) : Array DialectName := - match lookup name with + let s := s.insert initDialect.name + let imports (name : DialectName) : Array DialectName := + match dialects[name]? with | some d => d.imports | none => #[] - s.addPostTC next name |>.values - -def resolveDialects (lookup : String → Option Dialect) (dialects : Array DialectName) : Except String (Array Dialect) := do - dialects.mapM fun name => - match lookup name with - | none => throw s!"Unknown dialect {name}" - | some d => pure d + s.addAllPostorder imports name |>.toArray abbrev CategoryName := QualifiedIdent -def forbiddenCategories : Std.HashSet CategoryName := DDM.Integration.forbiddenCategories +/-- +Dialect names that are not allowed. +-/ +def reservedCatNames : Std.HashSet String := { "Type" } -private def forbiddenWellDefined : Bool := - forbiddenCategories.all fun nm => - match nm.dialect with - | "Init" => nm.name ∈ initDialect - | "StrataDDL" => nm.name ∈ StrataDDL +/-- +This returns true if the name is a category in the Init or StrataDDL dialect. +-/ +def builtinsWellDefined (category : QualifiedIdent) : Bool := + match category.dialect with + | "Init" => + match initDialect.cache[category.name]? with + | some (.syncat _) => true + | _ => false + | "StrataDDL" => + match StrataDDL.cache[category.name]? with + | some (.syncat _) => true | _ => false + | _ => false -#guard "BindingType" ∈ initDialect.cache -#guard "Binding" ∈ StrataDDL.cache -#guard forbiddenWellDefined +/-- +This maps category names in the Init dialect that are already declared +to their fully qualified Lean name. +-/ +def declaredCategories : Std.HashMap CategoryName Name := .ofList [ + (q`Init.Ident, ``String), + (q`Init.Num, ``Nat), + (q`Init.Decimal, ``Decimal), + (q`Init.Str, ``String), + (q`Init.ByteArray, ``ByteArray), + (q`Init.Bool, ``Bool) +] + +#guard declaredCategories.all fun nm _ => builtinsWellDefined nm + +/-- +Maps builtin polymorphic categories to their Lean representation +-/ +def polymorphicBuiltinCategories : Std.HashMap QualifiedIdent Name := + .ofList [ + (q`Init.CommaSepBy, ``Array), + (q`Init.Option, ``Option), + (q`Init.Seq, ``Array), + ] + +def polyCatMap : Std.HashMap QualifiedIdent Lean.Expr := .ofList [ + (q`Init.CommaSepBy, .const ``Array [0]), + (q`Init.Option, .const ``Option [0]), + (q`Init.Seq, .const ``Array [0]), +] + +/-- +Privatte categories are categories that should not be directly +used in user dialects. + +Note. As of 1/6/2026, this is not checked during Dialect parsing; +we should fix this. +-/ +def privateCategories : Std.HashSet CategoryName := { + q`Init.TypeExpr, + q`Init.BindingType, + q`StrataDDL.Binding +} +#guard privateCategories.all builtinsWellDefined + +def ignoreCategory (cat : CategoryName) : Bool := + cat ∈ declaredCategories ∨ cat ∈ privateCategories /-- Special categories ignore operations introduced in Init, but are populated with operators via functions/types. + +User dialects should not have operators that extend them, but operators +and functions may reference them. -/ -def specialCategories : Std.HashSet CategoryName := DDM.Integration.abstractCategories +def specialCategories : Std.HashSet CategoryName := { + q`Init.Expr, + q`Init.Type, + q`Init.TypeP +} +#guard specialCategories.all builtinsWellDefined + + +def annTypeExpr (base ann : Lean.Expr) := mkApp2 (mkConst ``Ann) base ann /-- Argument declaration for code generation. -/ -structure GenArgDecl where +structure CatOpArg where name : String cat : SyntaxCat - unwrap : Bool := false + wrap : Bool := true + +/-- +An operation at the category level. +-/ +structure CatOp where + name : QualifiedIdent + argDecls : Array CatOpArg + /-- A constructor in a generated datatype. @@ -171,27 +369,38 @@ structure DefaultCtor where The name in the Strata dialect for this constructor. If `none`, then this must be an auto generated constructor. -/ - strataName : Option QualifiedIdent - argDecls : Array GenArgDecl + strataName : Option QualifiedIdent := none + /-- + Flag indicating the generated constructor should add an annotation field as a + first argument. + -/ + includeAnn : Bool := true + /-- + Argument declarations + -/ + argDecls : Array CatOpArg + /-- + Either annotations are included or there is a single argument we can get + the annotation from. + -/ + includeAnnInvariant : + includeAnn ∨ + if p : argDecls.size = 1 then + (argDecls[0]'(p ▸ Nat.zero_lt_one)).wrap = true + else + false := by simp def DefaultCtor.leanName (c : DefaultCtor) : Name := .str .anonymous c.leanNameStr -/-- -An operation at the category level. --/ -structure CatOp where - name : QualifiedIdent - argDecls : Array GenArgDecl - namespace CatOp partial def checkCat (op : QualifiedIdent) (c : SyntaxCat) : Except String Unit := do c.args.forM (checkCat op) let f := c.name - if f ∈ forbiddenCategories then + if f ∈ privateCategories then throw s!"{op.fullName} refers to unsupported category {f.fullName}." -def ofArgDecl (op : QualifiedIdent) (d : ArgDecl) : Except String GenArgDecl := do +def ofArgDecl (op : QualifiedIdent) (d : ArgDecl) : Except String CatOpArg := do let cat ← match d.kind with | .type tp => @@ -201,7 +410,7 @@ def ofArgDecl (op : QualifiedIdent) (d : ArgDecl) : Except String GenArgDecl := pure c -- Check if unwrap metadata is present let unwrap := q`StrataDDL.unwrap ∈ d.metadata - pure { name := d.ident, cat, unwrap } + pure { name := d.ident, cat, wrap := !unwrap } def ofOpDecl (d : DialectName) (o : OpDecl) : Except String CatOp := do let name := ⟨d, o.name⟩ @@ -220,62 +429,103 @@ def ofFunctionDecl (d : DialectName) (o : FunctionDecl) : Except String CatOp := end CatOp -/-- -This maps names of categories that we are going to declare to -the list of operators in that category. --/ -abbrev CatOpMap := Std.HashMap CategoryName (Array CatOp) - -structure CatOpState where - map : CatOpMap +structure ErrorBundle (α : Type) where + value : α errors : Array String := #[] --- Monad that collects errors from adding declarations. -abbrev CatOpM := StateM CatOpState +def ErrorM (α : Type) := StateM (ErrorBundle α) + deriving Monad -def CatOpM.addError (msg : String) : CatOpM Unit := - modify fun s => { s with errors := s.errors.push msg } +namespace ErrorM -def mkRootIdent (name : Name) : Ident := - let rootName := `_root_ ++ name - .mk (.ident .none name.toString.toSubstring rootName [.decl name []]) +def addError {α} (msg : String) : ErrorM α Unit := + (modify fun s => { s with errors := s.errors.push msg } : StateM _ _) -/-- Maps primitive Init categories to their Lean types. -/ -def declaredCategories : Std.HashMap CategoryName Name := .ofList [ - (q`Init.Ident, ``String), - (q`Init.Num, ``Nat), - (q`Init.Decimal, ``Decimal), - (q`Init.Str, ``String), - (q`Init.ByteArray, ``ByteArray), - (q`Init.Bool, ``Bool) -] +instance {α} : MonadState α (ErrorM α) where + get := (return (←get).value : StateM _ _) + set v := (modify (fun s => { value := v, errors := s.errors}) : StateM _ _) + modifyGet f := modifyGet (m := StateM _) fun ⟨v, errors⟩ => + let (a, v) := f v + (a, ⟨v, errors⟩) -#guard declaredCategories.keys.all (DDM.Integration.primitiveCategories.contains ·) +end ErrorM -def ignoredCategories : Std.HashSet CategoryName := - .ofList declaredCategories.keys ∪ forbiddenCategories +structure CatIndexMap where + idents : Array QualifiedIdent + ops : Vector (Array CatOp) idents.size + map : Std.HashMap QualifiedIdent Nat + inv : ∀(name : QualifiedIdent) (p : name ∈ map), map[name] < idents.size -namespace CatOpMap +namespace CatIndexMap -def addCat (m : CatOpMap) (cat : CategoryName) : CatOpMap := - -- Allow Init.Bool even though it's in ignoredCategories - if cat ∈ ignoredCategories && cat ≠ q`Init.Bool then - m - else - m.insert cat #[] +protected def empty : CatIndexMap where + idents := #[] + ops := #v[] + map := {} + inv := fun name mem => by simp at mem + +instance : EmptyCollection CatIndexMap where + emptyCollection := .empty + +instance : Inhabited CatIndexMap where + default := .empty + +abbrev size (m : CatIndexMap) := m.idents.size + +def indexOf? (cats : CatIndexMap) (cat : QualifiedIdent) + : Option (Fin cats.size) := + match h : cats.map[cat]? with + | none => none + | some idx => + have idx_lt : idx < cats.size := by + have ⟨p, eq⟩ := Std.HashMap.getElem?_eq_some_iff.mp h + simp [← eq] + exact cats.inv cat _ + some ⟨idx, idx_lt⟩ + +-- Monad that collects errors from adding declarations. +abbrev CatIndexM := ErrorM CatIndexMap + +/-- +Add a category to the map. +-/ +def addCat (m : CatIndexMap) (cat : CategoryName) (_ : cat ∉ m.map): CatIndexMap := + let n := m.idents.size + { idents := m.idents.push cat + ops := m.ops.push #[] |>.cast (by simp) + map := m.map.insert cat n + inv := fun name namep => by + simp at namep + if h : cat = name then + simp [h] + omega + else + have inv := m.inv name + grind + } -def addOp (m : CatOpMap) (cat : CategoryName) (op : CatOp) : CatOpMap := - assert! cat ∈ m - m.modify cat (fun a => a.push op) +def addCatM (cat : CategoryName) : CatIndexM Unit := do + if ignoreCategory cat then + pure () + else + let already ← modifyGet fun s => + if mem : cat ∈ s.map then + (true, s) + else + (false, s.addCat cat mem) + if already then + .addError s!"Duplicate category {cat}" -def addCatM (cat : CategoryName) : CatOpM Unit := do - modify fun s => { s with map := s.map.addCat cat } -def addOpM (cat : CategoryName) (op : CatOp) : CatOpM Unit := do - modify fun s => { s with map := s.map.addOp cat op } +def addOpM (cat : CategoryName) (op : CatOp) : CatIndexM Unit := do + match (←get).indexOf? cat with + | none => + .addError s!"Missing operator category {cat}" + | some idx => + modify fun m => { m with ops := Vector.modify! m.ops idx.val (·.push op) } -def addDecl (d : DialectName) (decl : Decl) : CatOpM Unit := - let addCatOp (cat : QualifiedIdent) (act : Except String CatOp) : CatOpM Unit := +def addDecl (d : DialectName) (decl : Decl) : CatIndexM Unit := + let addCatOp (cat : QualifiedIdent) (act : Except String CatOp) : CatIndexM Unit := match act with | .ok op => addOpM cat op @@ -285,11 +535,12 @@ def addDecl (d : DialectName) (decl : Decl) : CatOpM Unit := | .syncat decl => addCatM ⟨d, decl.name⟩ | .op decl => do - -- Allow Init.Bool operators even though Bool is in declaredCategories - let isBoolOp := decl.category == q`Init.Bool && (decl.name == "boolTrue" || decl.name == "boolFalse") - if (decl.category ∈ ignoredCategories ∨ decl.category ∈ specialCategories) && !isBoolOp then + let cat := decl.category + if ignoreCategory cat ∨ cat ∈ specialCategories then + -- Ignored and special category operators are ignored in `Init`, but + -- generate errors when operators are in other dialects. if d ≠ "Init" then - .addError s!"Skipping operation {decl.name} in {d}: {decl.category.fullName} cannot be extended." + .addError s!"Skipping operation {decl.name} in {d}: {cat} cannot be extended." else addCatOp decl.category (CatOp.ofOpDecl d decl) | .type decl => @@ -299,124 +550,232 @@ def addDecl (d : DialectName) (decl : Decl) : CatOpM Unit := | .metadata _ => pure () -def addDialect (d : Dialect) : CatOpM Unit := +def addDialect (d : Dialect) : CatIndexM Unit := d.declarations.forM (addDecl d.name) -/- `CatopMap` with onl initial dialect-/ -protected def init : CatOpMap := - let act := do - addDialect initDialect - let ((), s) := act { map := {}, errors := #[] } +/- `CatopMap` with only initial dialect-/ +protected def init : CatIndexMap := + let ((), s) := addDialect initDialect { value := {}, errors := #[] } if s.errors.size > 0 then panic! s!"Error in Init dialect {s.errors}" else - s.map + s.value -end CatOpMap +end CatIndexMap -def mkCatOpMap (a : Array Dialect) : CatOpMap × Array String := +def mkCatIndexMap (a : Array Dialect) : CatIndexMap × Array String := let act := - a.forM fun d => if d.name = "Init" then pure () else CatOpMap.addDialect d - let ((), s) := act { map := CatOpMap.init, errors := #[] } - (s.map, s.errors) - -/-- -A set of categories. --/ -abbrev CategorySet := Std.HashSet CategoryName + a.forM fun d => + if d.name = "Init" then + pure () + else + CatIndexMap.addDialect d + let ((), s) := act { value := CatIndexMap.init, errors := #[] } + (s.value, s.errors) -namespace SyntaxCatF +structure WorkSet (n : Nat) where + set : Vector Bool n + pending : Array (Fin n) + inv : ∀idx, idx ∈ pending → set[idx.val] = true -/-- -Invoke `f` over all atomic (no argument) category names in `c`. --/ -private -def foldOverAtomicCategories {α} (cat : SyntaxCat) (init : α) (f : α → QualifiedIdent → α) : α := - if cat.args.size = 0 then - f init cat.name - else - cat.args.foldl (init := init) fun v a => foldOverAtomicCategories a v f -decreasing_by - rw [sizeOf_spec cat] - decreasing_tactic +namespace WorkSet -end SyntaxCatF +def remainingCount {n} (s : WorkSet n) : Nat := + (s.set.toArray.filter (!·)).size + s.pending.size -structure WorkSet (α : Type _) [BEq α] [Hashable α] where - set : Std.HashSet α - pending : Array α +def empty {n} : WorkSet n := + { set := .replicate _ false + pending := #[] + inv := fun idx idxp => by simp only [Array.mem_empty_iff] at idxp + } -def WorkSet.ofSet [BEq α] [Hashable α] (set : Std.HashSet α) : WorkSet α where - set := set - pending := set.toArray +def addIdx {n} (s : WorkSet n) (idx : Fin n) : WorkSet n := + if p : s.set[idx] = true then + s + else + let { set, pending, inv } := s + { set := set.set idx true + pending := pending.push idx + inv := by grind + } + +theorem remainingCount_addIdx {n} (s : WorkSet n) (idx : Fin n) : + (s.addIdx idx).remainingCount = s.remainingCount := by + simp only [WorkSet.addIdx] + if h : s.set[idx] = true then + simp [h] + else + have eq_false : s.set[idx.val] = false := + iff_of_eq (Bool.not_eq_true (s.set[idx.val])) |>.mp h + simp [-Vector.size_toArray, -Array.size_set, + remainingCount, Array.size_filter_set, eq_false] + have false_in : false ∈ s.set := ⟨Array.mem_of_getElem eq_false⟩ + have size_pos : (Array.filter (fun x => !x) s.set.toArray).size > 0 := by + simp [-Vector.size_toArray, Array.size_filter_pos_iff, false_in] + omega + + +@[inline] +def pop {n} (s : WorkSet n) (p : s.pending.size > 0) : WorkSet n := + let { set, pending, inv } := s + { set + pending := pending.pop + inv := by + intro idx mem + apply inv + simp [Array.mem_pop_ne p, mem] + } + +theorem remaining_count_pop {m} (s : WorkSet m) (p : s.pending.size > 0) : + (s.pop p).remainingCount + 1 = s.remainingCount := by + simp_all [pop, remainingCount] + grind + +end WorkSet + +def addIdent (m : CatIndexMap) (errors : Array String) + (op : QualifiedIdent) (arg : String) (nm : QualifiedIdent) : Array String := + if nm ∈ declaredCategories then + errors + else if nm ∈ privateCategories then + errors.push s!"{op} {arg} cannot reference private category {nm}." + else if nm ∈ m.map then + errors + else + errors.push s!"{op} {arg} references undeclared category {nm}." -def WorkSet.add [BEq α] [Hashable α] (s : WorkSet α) (a : α) : WorkSet α := - let { set, pending } := s - let (mem, set) := set.containsThenInsert a - let pending := if mem then pending else pending.push a - { set, pending } +def addCatArg (m : CatIndexMap) (errors : Array String) (op : QualifiedIdent) (arg : String) + (cat : SyntaxCat) : Array String := + cat.foldOverAtomicCategories (init := errors) (addIdent m · op arg ·) -def WorkSet.pop [BEq α] [Hashable α] (s : WorkSet α) : Option (WorkSet α × α) := - let { set, pending } := s - if p : pending.size > 0 then - some ({ set, pending := pending.pop }, pending[pending.size -1]) - else - none /-- Add all atomic categories in bindings to set. -/ -private def addArgCategories (s : CategorySet) (args : ArgDecls) : CategorySet := - args.foldl (init := s) fun s b => - b.kind.categoryOf.foldOverAtomicCategories (init := s) (·.insert ·) - -partial def mkUsedCategories.aux (m : CatOpMap) (s : WorkSet CategoryName) : CategorySet := - match s.pop with - | none => s.set - | some (s, c) => - match c with - | q`Init.TypeP => - mkUsedCategories.aux m (s.add q`Init.Type) - | _ => - let ops := m.getD c #[] - let addArgs {α:Type} (f : α → CategoryName → α) (a : α) (op : CatOp) := - op.argDecls.foldl (init := a) fun r arg => arg.cat.foldOverAtomicCategories (init := r) f - let addName (pa : WorkSet CategoryName) (c : CategoryName) := pa.add c - let s := ops.foldl (init := s) (addArgs addName) - mkUsedCategories.aux m s - -def mkUsedCategories (m : CatOpMap) (d : Dialect) : CategorySet := +def addArgCategories (m : CatIndexMap) (errors : Array String) + (op : QualifiedIdent) (args : ArgDecls) : Array String := + args.foldl (init := errors) fun errors b => addCatArg m errors op b.ident b.kind.categoryOf + +def addNewIdent (m:CatIndexMap) (s : Array String) (dialect : String) (nm : QualifiedIdent) : + Array String := + if nm ∈ declaredCategories then + s + else if nm ∈ privateCategories then + s.push s!"{dialect} cannot reference private category {nm}." + else if nm ∉ m.map then + s.push s!"{dialect} references undeclared category {nm}." + else + s + +def checkDialect (m : CatIndexMap) (d : Dialect) : Array String := let dname := d.name - let cats := d.declarations.foldl (init := {}) fun s decl => + d.declarations.foldl (init := #[]) fun s decl => match decl with - | .syncat decl => s.insert ⟨dname, decl.name⟩ + | .syncat decl => + addNewIdent m s dname ⟨dname, decl.name⟩ | .op decl => - let s := s.insert decl.category - let s := addArgCategories s decl.argDecls - s + let opName : QualifiedIdent := ⟨dname, decl.name⟩ + if decl.category ∈ specialCategories then + s.push s!"{opName} extends special category {decl.category}." + else + let s := addNewIdent m s dname decl.category + let s := addArgCategories m s opName decl.argDecls + s | .type _ => - s.insert q`Init.Type + addNewIdent m s dname q`Init.Type | .function decl => - let s := s.insert q`Init.Expr - let s := addArgCategories s decl.argDecls + let s := addNewIdent m s dname q`Init.Expr + let s := addArgCategories m s ⟨dname, decl.name⟩ decl.argDecls s | .metadata _ => s - mkUsedCategories.aux m (.ofSet cats) + +/-- +Convert the category index map into a graph. +-/ +def indexMapToGraph (cats : CatIndexMap) : OutGraph cats.size := + let n := cats.size + let g : OutGraph n := OutGraph.empty n + -- Build map from qualified identifier to index in categories. + let identIndexMap := cats.map + + let addArgIndices (cat : QualifiedIdent) + (opName : QualifiedIdent) + (c : SyntaxCat) + (init : OutGraph n) + (resIdx : Fin n) : OutGraph n := + c.foldOverAtomicCategories (init := init) fun g q => + match cats.indexOf? q with + | some i => g.addEdge i resIdx + | none => g + n.fold (init := g) fun i ip g => Id.run do + let cat := cats.idents[i] + let ops := cats.ops[i] + ops.foldl (init := g) fun g op => + op.argDecls.foldl (init := g) fun g arg => + addArgIndices cat op.name arg.cat g ⟨i, ip⟩ + + +/-- +Return indices of categories that are introduced or extended in this dialect. +-/ +def newCategories (m : CatIndexMap) (d : Dialect) : Std.HashSet (Fin m.size) := + -- Generate set of all categories that appear in this file + let addIdent (s : Std.HashSet (Fin m.size)) (nm : QualifiedIdent) := + match m.indexOf? nm with + | some idx => s.insert idx + | none => s + + let rec addCat (s : Std.HashSet (Fin m.size)) (cat : SyntaxCat) := + if cat.args.size = 0 then + match m.indexOf? cat.name with + | some idx => s.insert idx + | none => s + else + cat.args.foldl (init := s) fun s i => addCat s i + + d.declarations.foldl (init := {}) fun s decl => + match decl with + | .syncat decl => addIdent s ⟨d.name, decl.name⟩ + | .op decl => addIdent s decl.category + | .type _ => addIdent s q`Init.Type + | .function _ => addIdent s q`Init.Expr + | .metadata _ => s + +def mkUsedCategories (m : CatIndexMap) (d : Dialect) : Vector Bool m.size := + -- Generate set of all categories that appear in this file + OutGraph.preimage_closure (indexMapToGraph m) (newCategories m d) def mkStandardCtors (exprHasEta : Bool) (cat : QualifiedIdent) : Array DefaultCtor := match cat with | q`Init.Expr => + let fvar := { + leanNameStr := "fvar" + argDecls := #[{ name := "idx", cat := .atom .none q`Init.Num, wrap := false }] + } if exprHasEta then - #[ - .mk "bvar" none #[{ name := "idx", cat := .atom .none q`Init.Num }], - .mk "lambda" none #[ - { name := "var", cat := .atom .none q`Init.Str }, - { name := "type", cat := .atom .none q`Init.Type }, - { name := "fn", cat := .atom .none cat } - ] + #[fvar, + { leanNameStr := "bvar" + argDecls := #[{ name := "idx", cat := .atom .none q`Init.Num, wrap := false }] + }, + { + leanNameStr := "lambda" + argDecls := #[ + { name := "var", cat := .atom .none q`Init.Str }, + { name := "type", cat := .atom .none q`Init.Type }, + { name := "fn", cat := .atom .none cat } + ] + } ] else - #[] + #[fvar] + | q`Init.TypeP => + #[ + { leanNameStr := "expr", + includeAnn := false, + argDecls := #[{ name := "tp", cat := .atom .none q`Init.Type }] + }, + { leanNameStr := "type", argDecls := #[] } + ] | _ => #[] @@ -444,142 +803,146 @@ def toDefaultOp (s : Std.HashSet String) (op : CatOp) : DefaultCtor := argDecls := op.argDecls } -def CatOpMap.onlyUsedCategories (m : CatOpMap) (d : Dialect) (exprHasEta : Bool) : Array (QualifiedIdent × Array DefaultCtor) := - let usedSet := mkUsedCategories m d - m.fold (init := #[]) fun a cat ops => - if cat ∉ declaredCategories ∧ cat ∈ usedSet then - let usedNames : Std.HashSet String := - match cat with - | q`Init.Expr => { "fvar" } - | _ => {} - let standardCtors := mkStandardCtors exprHasEta cat - let usedNames : Std.HashSet String := - standardCtors.foldl (init := usedNames) fun m c => - assert! c.leanNameStr ∉ m - m.insert c.leanNameStr - let (allCtors, _) := ops.foldl (init := (standardCtors, usedNames)) fun (a, s) op => - let dOp := toDefaultOp s op - (a.push dOp, s.insert dOp.leanNameStr) - a.push (cat, allCtors) - else - a - -/- Returns an identifier from a string. -/ -def localIdent (name : String) : Ident := - let dName := .anonymous |>.str name - .mk (.ident .none name.toSubstring dName []) - -def orderedSyncatGroups (categories : Array (QualifiedIdent × Array DefaultCtor)) : Array (Array (QualifiedIdent × Array DefaultCtor)) := +/-- Array with undeclared categories and their constructors. -/ +abbrev CatOpArray := Array (QualifiedIdent × Array DefaultCtor) + +@[inline] +def ctors (exprHasEta : Bool) (m : CatIndexMap) (i : Fin m.size) : QualifiedIdent × Array DefaultCtor := + let cat := m.idents[i] + let standardCtors := mkStandardCtors exprHasEta cat + let usedNames : Std.HashSet String := {} + let ops := m.ops[i] + let (allCtors, _) := + ops.foldl (init := (standardCtors, usedNames)) fun (a, s) op => + let dOp := toDefaultOp s op + (a.push dOp, s.insert dOp.leanNameStr) + (cat, allCtors) + +def catOpToGraph (categories : CatOpArray) : OutGraph categories.size := let n := categories.size let g : OutGraph n := OutGraph.empty n + -- Build map from qualified identifier to index in categories. let identIndexMap : Std.HashMap QualifiedIdent (Fin n) := n.fold (init := {}) fun i p m => m.insert categories[i].fst ⟨i, p⟩ - let getIndex (nm : QualifiedIdent) : Option (Fin n) := - identIndexMap[nm]? - let addArgIndices (cat : QualifiedIdent) (opName : String) (c : SyntaxCat) (init : OutGraph n) (resIdx : Fin n) : OutGraph n := + + let addArgIndices (cat : QualifiedIdent) + (opName : String) + (c : SyntaxCat) + (init : OutGraph n) + (resIdx : Fin n) : OutGraph n := c.foldOverAtomicCategories (init := init) fun g q => if q ∈ declaredCategories then g else - match getIndex q with + match identIndexMap[q]? with | some i => g.addEdge i resIdx | none => panic! s!"{opName} in {cat} has unknown category {q.fullName}" - let g : OutGraph n := - categories.foldl (init := g) fun g (cat, ops) => Id.run do - let some resIdx := getIndex cat - | panic! s!"Unknown category {cat}" - match cat with - | q`Init.TypeP => - let some typeIdx := getIndex q`Init.Type - | panic! s!"Unknown category Init.Type." - g.addEdge typeIdx resIdx - | _ => - ops.foldl (init := g) fun g op => - op.argDecls.foldl (init := g) fun g arg => - addArgIndices cat op.leanNameStr arg.cat g resIdx - let indices := OutGraph.tarjan g - indices.map (·.map (categories[·])) - -def mkCategoryIdent (scope : Name) (name : Name) : Ident := - let mkDeclName (comp : List Name) : Ident := - let subName := comp.foldl (init := .anonymous) fun r nm => r ++ nm - let sName := toString subName - .mk (.ident .none sName.toSubstring subName [.decl name []]) - - let rec aux : Name → List Name → Ident - | .anonymous, _ => mkRootIdent name - | n@(.num p' v), r => - if scope == n then - mkDeclName r - else - aux p' (.num .anonymous v :: r) - | n@(.str p' v), r => - if scope == n then - mkDeclName r - else - aux p' (.str .anonymous v :: r) - aux name [] + n.fold (init := g) fun i ip g => Id.run do + let (cat, ops) := categories[i] + ops.foldl (init := g) fun g op => + op.argDecls.foldl (init := g) fun g arg => + addArgIndices cat op.leanNameStr arg.cat g ⟨i, ip⟩ -/-- -Prepend the current namespace to the Lean name and convert to an identifier. --/ -def scopedIdent (scope subName : Lean.Name) : Ident := - let name := scope ++ subName - let nameStr := toString subName - .mk (.ident .none nameStr.toSubstring subName [.decl name []]) +structure GenContext where + -- Syntax for #strata_gen for source location purposes. + src : Lean.Syntax + /-- + Maps category identifiers to their relative Lean name. + -/ + categoryNameMap : Std.HashMap QualifiedIdent Name -/-- -Prepend the current namespace to the Lean name and convert to an identifier. --/ -def mkScopedIdent {m} [Monad m] [Lean.MonadResolveName m] (subName : Lean.Name) : m Ident := - (scopedIdent · subName) <$> getCurrNamespace +abbrev GenM := ReaderT GenContext CommandElabM -/-- Return identifier for operator with given name to suport category. -/ +def runCmd {α} (act : CommandElabM α) : GenM α := fun _ => act + +/-- Create a fresh name. -/ +def genFreshLeanName (s : String) : GenM Name := do + let fresh ← modifyGet fun s => (s.nextMacroScope, { s with nextMacroScope := s.nextMacroScope + 1 }) + let n : Name := .anonymous |>.str s + return Lean.addMacroScope (← getEnv).mainModule n fresh + +/-- Create a fresh name. -/ +def genFreshIdentPair (s : String) : GenM (Ident × Ident) := do + let name ← genFreshLeanName s + let src := (←read).src + return (mkIdentFrom (canonical := true) src name, mkIdentFrom src name) + +/-- Create a identifier from a name. -/ +def genIdentFrom (name : Name) (canonical : Bool := false) : GenM Ident := do + return mkIdentFrom (←read).src name canonical + +/-- Return identifier for operator with given + name to Lean name. -/ def getCategoryScopedName (cat : QualifiedIdent) : GenM Name := do match (←read).categoryNameMap[cat]? with | some catName => - return .mkSimple catName + return catName | none => return panic! s!"getCategoryScopedName given {cat}" -/-- Return identifier for type that implements given category. -/ -def getCategoryIdent (cat : QualifiedIdent) : GenM Ident := do - if let some nm := declaredCategories[cat]? then - return mkRootIdent nm - currScopedIdent (← getCategoryScopedName cat) +/-- +`gmkCatExpr annType c unwrap` returns the Lean type of the `c` with the given +`unwrap` flag and annotation type `annType`. + +This expression must have type `Type`. +-/ +def leanCatTypeExpr (annType : Lean.Expr) (c : SyntaxCat) (wrap : Bool := true) : + GenM Lean.Expr := do + let args ← c.args.attach.mapM (fun ⟨sc, _⟩ => leanCatTypeExpr annType sc) + + -- Handle polymorphic categories. + if let some nm := polymorphicBuiltinCategories[c.name]? then + assert! args.size == 1 + return annTypeExpr (mkAppN (.const nm [0]) args) annType + assert! args.size == 0 + + -- Handle declared categories + if let some nm := declaredCategories[c.name]? then + -- Check if unwrap is specified + if wrap then + return annTypeExpr (mkConst nm) annType + else + return mkConst nm -- Return unwrapped type + -- Handle base case + let relName ← getCategoryScopedName c.name + let catName := (← getCurrNamespace) ++ relName + let catType : Lean.Expr := mkConst catName + return .app catType annType +termination_by c +decreasing_by + cases c + decreasing_tactic -def getCategoryTerm (cat : QualifiedIdent) (annType : Ident) : GenM Term := do - let catIdent ← mkScopedIdent (← getCategoryScopedName cat) - return Lean.Syntax.mkApp catIdent #[annType] +/-- +`getCategoryTerm cat annType` returns the Lean term for a non-declared category. +-/ +def getCategoryTerm (cat : QualifiedIdent) (annType : Term) : GenM Term := do + let catIdent ← currScopedIdent (← getCategoryScopedName cat) + return mkApp catIdent #[annType] /-- Return identifier for operator with given name to suport category. -/ def getCategoryOpIdent (cat : QualifiedIdent) (name : Name) : GenM Ident := do currScopedIdent <| (← getCategoryScopedName cat) ++ name -partial def ppCatWithUnwrap (annType : Ident) (c : SyntaxCat) (unwrap : Bool) : GenM Term := do - let args ← c.args.mapM (ppCatWithUnwrap annType · false) - match c.name, eq : args.size with - | q`Init.CommaSepBy, 1 => - return mkCApp ``Ann #[mkCApp ``Array #[args[0]], annType] - | q`Init.Option, 1 => - return mkCApp ``Ann #[mkCApp ``Option #[args[0]], annType] - | q`Init.Seq, 1 => - return mkCApp ``Ann #[mkCApp ``Array #[args[0]], annType] - | cat, 0 => - match declaredCategories[cat]? with - | some nm => - -- Check if unwrap is specified - if unwrap && cat ∈ declaredCategories then - pure <| mkRootIdent nm -- Return unwrapped type - else - pure <| mkCApp ``Ann #[mkRootIdent nm, annType] - | none => do - getCategoryTerm cat annType - | f, _ => throwError "Unsupported {f.fullName}" - -partial def ppCat (annType : Ident) (c : SyntaxCat) : GenM Term := do - ppCatWithUnwrap annType c false +/-- +`ppCat annType c unwrap` returns the Lean type of the `c` with the given +`unwrap` flag and annotation type `annType`. +-/ +partial def mkCatTerm (annType : Term) (c : SyntaxCat) (wrap : Bool := true) : GenM Term := do + let args ← c.args.mapM (mkCatTerm annType) + let cat := c.name + if let some tp := polymorphicBuiltinCategories[cat]? then + let isTrue _ := inferInstanceAs (Decidable (args.size = 1)) + | throwError s!"internal: {cat} expects a single argument." + return mkCApp ``Ann #[mkCApp tp #[args[0]], annType] + if args.size ≠ 0 then + throwError "internal: Expected no arguments to {cat}." + if let some nm := declaredCategories[cat]? then + -- Check if unwrap is specified + let t := mkRootIdent nm + return if wrap then mkCApp ``Ann #[t, annType] else t + getCategoryTerm cat annType def elabCommands (commands : Array Command) : CommandElabM Unit := do let messageCount := (← get).messages.unreported.size @@ -606,41 +969,36 @@ def elabCommands (commands : Array Command) : CommandElabM Unit := do match hasNewMessage with | none => pure () | some m => - logError m!"Command elaboration reported messages:\n {commands}\n {m.kind}" + let msg := m!"Command elaboration reported messages:\nCommands:\n" + let msg := commands.foldl (init := msg) fun msg cmd => m!"{msg} {cmd}\n" + let msg := m!"{msg}Kind: {m.kind}\n" + let msg := m!"{msg}Message: {←m.data.format}" + logError msg abbrev BracketedBinder := TSyntax ``Lean.Parser.Term.bracketedBinder -def explicitBinder (name : String) (typeStx : Term) : CommandElabM BracketedBinder := do +def explicitBinder (name : String) (typeStx : Term) + : CommandElabM BracketedBinder := do let nameStx := localIdent name `(bracketedBinderF| ($nameStx : $typeStx)) def genCtor (annType : Ident) (op : DefaultCtor) : GenM (TSyntax ``ctor) := do let ctorId : Ident := localIdent op.leanNameStr + let ann ← + if op.includeAnn then do + pure #[← `(bracketedBinder| (ann : $annType))] + else + pure #[] let binders ← op.argDecls.mapM fun arg => do - explicitBinder arg.name (← ppCatWithUnwrap annType arg.cat arg.unwrap) - `(ctor| | $ctorId:ident (ann : $annType) $binders:bracketedBinder* ) + explicitBinder arg.name (← mkCatTerm annType arg.cat (wrap := arg.wrap)) + `(ctor| | $ctorId:ident $ann:bracketedBinder* $binders:bracketedBinder*) def mkInductive (cat : QualifiedIdent) (ctors : Array DefaultCtor) : GenM Command := do assert! cat ∉ declaredCategories - let ident ← mkScopedIdent (← getCategoryScopedName cat) + let ident ← currScopedIdent (← getCategoryScopedName cat) trace[Strata.generator] "Generating {ident}" let annType := localIdent "α" - let builtinCtors : Array (TSyntax ``ctor) ← - match cat with - | q`Init.Expr => do - pure #[ - ← `(ctor| | $(localIdent "fvar"):ident (ann : $annType) (idx : Nat)) - ] - | q`Init.TypeP => do - let typeIdent ← getCategoryTerm q`Init.Type annType - pure #[ - ← `(ctor| | $(localIdent "expr"):ident (tp : $typeIdent)), - ← `(ctor| | $(localIdent "type"):ident (tp : $annType)) - ] - | _ => - pure #[] `(inductive $ident ($annType : Type) : Type where - $builtinCtors:ctor* $(← ctors.mapM (genCtor annType)):ctor* deriving Repr) @@ -651,11 +1009,7 @@ def categoryToAstTypeIdent (cat : QualifiedIdent) (annType : Term) : Term := | q`Init.Type => ``Strata.TypeExprF | q`Init.TypeP => ``Strata.ArgF | _ => ``Strata.OperationF - Lean.Syntax.mkApp (mkRootIdent ident) #[annType] - -structure ToOp where - name : String - argDecls : Array (String × SyntaxCat) + mkApp (mkRootIdent ident) #[annType] def toAstIdentM (cat : QualifiedIdent) : GenM Ident := do currScopedIdent <| (← getCategoryScopedName cat) ++ `toAst @@ -664,52 +1018,41 @@ def ofAstIdentM (cat : QualifiedIdent) : GenM Ident := do currScopedIdent <| (← getCategoryScopedName cat) ++ `ofAst def mkAnnWithTerm (argCtor : Name) (annTerm v : Term) : Term := - mkCApp argCtor #[mkCApp ``Ann.ann #[annTerm], v] - -def annToAst (argCtor : Name) (annTerm : Term) : Term := - mkCApp argCtor #[mkCApp ``Ann.ann #[annTerm], mkCApp ``Ann.val #[annTerm]] + mkApp (mkCIdent argCtor) #[mkCApp ``Ann.ann #[annTerm], v] -partial def toAstApplyArg (vn : Name) (cat : SyntaxCat) (unwrap : Bool := false) : GenM Term := do +def annToAst' (argCtor : Name) (term : Term) (wrap : Bool) : Term := + if wrap then + mkAnnWithTerm argCtor term (mkCApp ``Ann.val #[term]) + else + mkApp (mkCIdent argCtor) #[mkCApp ``default #[], term] + +partial def annArg (c : SyntaxCat) (wrap : Bool) : GenM Ident := do + let cat := c.name + if cat ∈ polymorphicBuiltinCategories then + assert! c.args.size == 1 + return mkIdentFrom (←read).src ``Ann.ann + assert! c.args.size == 0 + if cat ∈ declaredCategories then + assert! wrap + return mkIdentFrom (←read).src ``Ann.ann + getCategoryOpIdent cat `ann + +partial def toAstApplyArg (vn : Name) (cat : SyntaxCat) (wrap : Bool := true) : + GenM Term := do let v := mkIdentFrom (←read).src vn match cat.name with | q`Init.Num => - if unwrap then - ``(ArgF.num default $v) - else - return annToAst ``ArgF.num v + return annToAst' ``ArgF.num v (wrap := wrap) | q`Init.Bool => do - if unwrap then - -- When unwrapped, v is a plain Bool. Create OperationF directly based on the value. - let defaultAnn ← ``(default) - let emptyArray ← ``(#[]) - let trueOp := mkCApp ``OperationF.mk #[defaultAnn, quote q`Init.boolTrue, emptyArray] - let falseOp := mkCApp ``OperationF.mk #[defaultAnn, quote q`Init.boolFalse, emptyArray] - let opExpr ← ``(if $v then $trueOp else $falseOp) - ``(ArgF.op $opExpr) - else - -- When wrapped, v is already Ann Bool α - let boolToAst := mkCApp ``Strata.Bool.toAst #[v] - return mkCApp ``ArgF.op #[boolToAst] + return mkCApp ``ArgF.op #[annToAst' ``OperationF.ofBool v (wrap := wrap)] | q`Init.Ident => - if unwrap then - ``(ArgF.ident default $v) - else - return annToAst ``ArgF.ident v + return annToAst' ``ArgF.ident v (wrap := wrap) | q`Init.Str => - if unwrap then - ``(ArgF.strlit default $v) - else - return annToAst ``ArgF.strlit v + return annToAst' ``ArgF.strlit v (wrap := wrap) | q`Init.Decimal => - if unwrap then - ``(ArgF.decimal default $v) - else - return annToAst ``ArgF.decimal v + return annToAst' ``ArgF.decimal v (wrap := wrap) | q`Init.ByteArray => - if unwrap then - ``(ArgF.bytes default $v) - else - return annToAst ``ArgF.bytes v + return annToAst' ``ArgF.bytes v (wrap := wrap) | cid@q`Init.Expr => do let toAst ← toAstIdentM cid return mkCApp ``ArgF.expr #[mkApp toAst #[v]] @@ -759,73 +1102,86 @@ partial def toAstApplyArg (vn : Name) (cat : SyntaxCat) (unwrap : Bool := false) abbrev MatchAlt := TSyntax ``Lean.Parser.Term.matchAlt -def toAstBuiltinMatches (cat : QualifiedIdent) : GenM (Array MatchAlt) := do - let src := (←read).src - match cat with - | q`Init.Expr => - let (annC, annI) ← genFreshIdentPair "ann" - let ctor ← getCategoryOpIdent cat `fvar - let pat : Term := mkApp ctor #[annC, mkCanIdent src `idx] - let rhs := mkCApp ``ExprF.fvar #[annI, mkIdentFrom src `idx] - return #[← `(matchAltExpr| | $pat => $rhs)] - | q`Init.TypeP => do - let (annC, annI) ← genFreshIdentPair "ann" - let typeC ← getCategoryOpIdent cat `type - let typeP : Term := mkApp typeC #[annC] - let typeCat := Lean.Syntax.mkCApp ``SyntaxCatF.atom #[annI, quote q`Init.Type] - let typeRhs := Lean.Syntax.mkCApp ``ArgF.cat #[typeCat] - let typeN ← genFreshLeanName "type" - let exprP := mkApp (← getCategoryOpIdent cat `expr) #[mkCanIdent src typeN] - let exprRhs ← toAstApplyArg typeN (.atom .none q`Init.Type) - return #[ - ← `(matchAltExpr| | $typeP => $typeRhs), - ← `(matchAltExpr| | $exprP => $exprRhs) - ] - | _ => - return #[] +def toAstExprMatch (op : DefaultCtor) (annT : Term) (args : Array CatOpArg) (names : Vector Name args.size) : GenM Term := do + let lname := op.leanNameStr + if lname == "fvar" then + let .isTrue arg_size_eq := inferInstanceAs (Decidable (args.size = 1)) + | return panic! s!"fvar expected 1 argument" + let src := (←read).src + return mkCApp ``ExprF.fvar #[annT, mkIdentFrom src names[0]] + let some nm := op.strataName + | return panic! s!"Unexpected builtin expression {lname}" + let init := mkCApp ``ExprF.fn #[annT, quote nm] + Fin.foldlM args.size (init := init) fun a i => do + let nm := names[i] + let d := args[i] + let e ← toAstApplyArg nm d.cat (wrap := d.wrap) + return Lean.Syntax.mkCApp ``ExprF.app #[annT, a, e] def toAstMatch (cat : QualifiedIdent) (op : DefaultCtor) : GenM MatchAlt := do let src := (←read).src let argDecls := op.argDecls - let (annC, annI) ← genFreshIdentPair "ann" let ctor : Ident ← getCategoryOpIdent cat op.leanName - let args ← argDecls.mapM fun arg => do - return (← genFreshLeanName arg.name, arg.cat, arg.unwrap) - let argTerms : Array Term := args.map fun p => mkCanIdent src p.fst - let pat : Term ← ``($ctor $annC $argTerms:term*) + let argc := argDecls.size + let argNames : Vector Name argc ← Vector.ofFnM fun (i : Fin argc) => + genFreshLeanName argDecls[i].name + let ((patArgs, annT) : Array Term × Term) ← + if h : op.includeAnn then + let (annC, annI) ← genFreshIdentPair "ann" + pure (#[(annC : Term)], (annI : Term)) + else + let argc1 : op.argDecls.size = 1 := by + have inv := op.includeAnnInvariant + grind + let d : CatOpArg := op.argDecls[0] + let annF : Ident ← annArg d.cat (wrap := d.wrap) + pure (#[], mkApp annF #[mkIdentFrom src argNames[0]]) + let pat := + let argTerms : Array Ident := argNames.map (mkCanIdent src) |>.toArray + mkApp ctor (patArgs ++ argTerms) let rhs : Term ← match cat with | q`Init.Expr => - let lname := op.leanNameStr - let some nm := op.strataName - | return panic! s!"Unexpected builtin expression {lname}" - let init := mkCApp ``ExprF.fn #[annI, quote nm] - args.foldlM (init := init) fun a (nm, tp, unwrap) => do - let e ← toAstApplyArg nm tp unwrap - return Lean.Syntax.mkCApp ``ExprF.app #[annI, a, e] + toAstExprMatch op annT argDecls argNames | q`Init.Type => do let some nm := op.strataName | return panic! "Expected type name" let toAst ← toAstIdentM cat - let argTerms ← arrayLit <| args.map fun (v, c, _unwrap) => - assert! c.isType - Lean.Syntax.mkApp toAst #[mkIdentFrom src v] - pure <| Lean.Syntax.mkCApp ``TypeExprF.ident #[annI, quote nm, argTerms] + let argTerms ← arrayLit <| Array.ofFn fun (i : Fin argc) => + assert! argDecls[i].cat.isType + mkApp toAst #[mkIdentFrom src argNames[i]] + pure <| mkApp (mkCIdent ``TypeExprF.ident) #[annT, quote nm, argTerms] + | q`Init.TypeP => do + match op.leanNameStr with + | "expr" => + let toAst ← toAstIdentM q`Init.Type + let .isTrue p := inferInstanceAs (Decidable (argc = 1)) + | return panic! "Expected one argument." + assert! argDecls[0].cat.isType + let a := mkApp toAst #[mkIdentFrom src argNames[0]] + pure <| mkCApp ``ArgF.type #[a] + | "type" => + let c := mkCApp ``SyntaxCatF.atom #[annT, quote q`Init.Type] + pure <| mkCApp ``ArgF.cat #[c] + | _ => + return panic! "Unknown typeP op" | _ => let mName ← match op.strataName with | some n => pure n - | none => throwError s!"Internal: Operation requires strata name" - let argTerms : Array Term ← args.mapM fun (nm, tp, unwrap) => toAstApplyArg nm tp unwrap - pure <| mkCApp ``OperationF.mk #[annI, quote mName, ← arrayLit argTerms] + | none => throwError s!"Internal: Operation {op.leanName} in {cat} requires strata name" + let argTerms : Array Term ← Array.ofFnM fun (i : Fin argc) => + let nm := argNames[i] + let d := argDecls[i] + toAstApplyArg nm d.cat (wrap := d.wrap) + pure <| mkCApp ``OperationF.mk #[annT, quote mName, ← arrayLit argTerms] `(matchAltExpr| | $pat => $rhs) def genToAst (cat : QualifiedIdent) (ops : Array DefaultCtor) : GenM Command := do let annType := localIdent "α" let catTerm ← getCategoryTerm cat annType let astType : Term := categoryToAstTypeIdent cat annType - let cases ← toAstBuiltinMatches cat - let cases : Array MatchAlt ← ops.mapM_off (init := cases) (toAstMatch cat) + let cases : Array MatchAlt ← ops.mapM (toAstMatch cat) let toAst ← toAstIdentM cat trace[Strata.generator] "Generating {toAst}" let src := (←read).src @@ -833,69 +1189,35 @@ def genToAst (cat : QualifiedIdent) (ops : Array DefaultCtor) : GenM Command := `(partial def $toAst {$annType : Type} [Inhabited $annType] ($(mkCanIdent src v) : $catTerm) : $astType := match $(mkIdentFrom src v):ident with $cases:matchAlt*) -mutual - -partial def getOfIdentArg (varName : String) (cat : SyntaxCat) (e : Term) : GenM Term := do - getOfIdentArgWithUnwrap varName cat false e +def addAnn (act : Name) (e : Term) (wrap : Bool) : Term := + let t := mkApp (mkCIdent act) #[e] + if wrap then + mkCApp ``Functor.map #[mkCApp ``Ann.mk #[mkCApp ``ArgF.ann #[e]], t] + else + t -partial def getOfIdentArgWithUnwrap (varName : String) (cat : SyntaxCat) (unwrap : Bool) (e : Term) : GenM Term := do +partial def getOfIdentArg (varName : String) (cat : SyntaxCat) (e : Term) (wrap : Bool := true) : GenM Term := do match cat.name with | q`Init.Num => - if unwrap then - ``((fun arg => match arg with - | ArgF.num _ val => pure val - | a => OfAstM.throwExpected "numeric literal" a) $e) - else - ``(OfAstM.ofNumM $e) + return addAnn ``OfAstM.ofNumM e (wrap := wrap) | q`Init.Ident => - if unwrap then - ``((fun arg => match arg with - | ArgF.ident _ val => pure val - | a => OfAstM.throwExpected "identifier" a) $e) - else - ``(OfAstM.ofIdentM $e) + return addAnn ``OfAstM.ofIdentM e (wrap := wrap) | q`Init.Str => - if unwrap then - ``((fun arg => match arg with - | ArgF.strlit _ val => pure val - | a => OfAstM.throwExpected "string literal" a) $e) - else - ``(OfAstM.ofStrlitM $e) + return addAnn ``OfAstM.ofStrlitM e (wrap := wrap) | q`Init.Decimal => - if unwrap then - ``((fun arg => match arg with - | ArgF.decimal _ val => pure val - | a => OfAstM.throwExpected "decimal literal" a) $e) - else - ``(OfAstM.ofDecimalM $e) + return addAnn ``OfAstM.ofDecimalM e (wrap := wrap) | q`Init.ByteArray => - if unwrap then - ``((fun arg => match arg with - | ArgF.bytes _ val => pure val - | a => OfAstM.throwExpected "byte array" a) $e) - else - ``(OfAstM.ofBytesM $e) + return addAnn ``OfAstM.ofBytesM e (wrap := wrap) | q`Init.Bool => do - if unwrap then - -- When unwrapped, extract just the Bool value from Ann Bool α - ``((fun arg => match arg with - | ArgF.op op => Functor.map Ann.val (Strata.Bool.ofAst op) - | a => OfAstM.throwExpected "boolean" a) $e) - else - let (vc, vi) ← genFreshIdentPair varName - let boolOfAst := mkCApp ``Strata.Bool.ofAst #[vi] - ``(OfAstM.ofOperationM $e fun $vc _ => $boolOfAst) + return addAnn ``Strata.Bool.ofAst e (wrap := wrap) | cid@q`Init.Expr => do - let (vc, vi) ← genFreshIdentPair <| varName ++ "_inner" let ofAst ← ofAstIdentM cid - ``(OfAstM.ofExpressionM $e fun $vc _ => $ofAst $vi) + let (vc, vi) ← genFreshIdentPair <| varName ++ "_inner" + return mkCApp ``OfAstM.ofExpressionM #[e, ←``(fun $vc _ => $ofAst $vi)] | cid@q`Init.Type => do - let (vc, vi) ← genFreshIdentPair varName - let ofAst ← ofAstIdentM cid - ``(OfAstM.ofTypeM $e fun $vc _ => $ofAst $vi) - | cid@q`Init.TypeP => do let ofAst ← ofAstIdentM cid - pure <| mkApp ofAst #[e] + let (vc, vi) ← genFreshIdentPair varName + return mkCApp ``OfAstM.ofTypeM #[e, ←``(fun $vc _ => $ofAst $vi)] | q`Init.CommaSepBy => do let c := cat.args[0]! let (vc, vi) ← genFreshIdentPair varName @@ -911,21 +1233,22 @@ partial def getOfIdentArgWithUnwrap (varName : String) (cat : SyntaxCat) (unwrap let (vc, vi) ← genFreshIdentPair varName let body ← getOfIdentArg "e" c vi ``(OfAstM.ofSeqM $e fun $vc _ => $body) + | cid@q`Init.TypeP => do + let ofAst ← ofAstIdentM cid + pure <| mkApp ofAst #[e] | cid => do assert! cat.args.isEmpty let (vc, vi) ← genFreshIdentPair varName let ofAst ← ofAstIdentM cid ``(OfAstM.ofOperationM $e fun $vc _ => $ofAst $vi) -end - -def ofAstArgs (argDecls : Array GenArgDecl) (argsVar : Ident) : GenM (Array Ident × Array (TSyntax ``doSeqItem)) := do +def ofAstArgs (argDecls : Array CatOpArg) (argsVar : Ident) : GenM (Array Ident × Array (TSyntax ``doSeqItem)) := do let argCount := argDecls.size let args ← Array.ofFnM (n := argCount) fun ⟨i, _isLt⟩ => do let arg := argDecls[i] let (vc, vi) ← genFreshIdentPair <| arg.name ++ "_bind" let av ← ``($argsVar[$(quote i)]) - let rhs ← getOfIdentArgWithUnwrap arg.name arg.cat arg.unwrap av + let rhs ← getOfIdentArg arg.name arg.cat av (wrap := arg.wrap) let stmt ← `(doSeqItem| let $vc ← $rhs:term) return (vi, stmt) return args.unzip @@ -937,7 +1260,7 @@ def ofAstMatch (nameIndexMap : Std.HashMap QualifiedIdent Nat) (op : DefaultCtor | return panic! s!"Unbound operator name {name}" `(matchAltExpr| | Option.some $(quote nameIndex) => $rhs) -def ofAstExprMatchRhs (cat : QualifiedIdent) (annI argsVar : Ident) (op : DefaultCtor) : GenM Term:= do +def ofAstExprMatchRhs (cat : QualifiedIdent) (annI argsVar : Ident) (op : DefaultCtor) : GenM Term := do let ctorIdent ← getCategoryOpIdent cat op.leanName let some nm := op.strataName | return panic! s!"Missing name for {op.leanName}" @@ -955,7 +1278,7 @@ def ofAstExprMatch (nameIndexMap : Std.HashMap QualifiedIdent Nat) let rhs ← ofAstExprMatchRhs cat annI argsVar op ofAstMatch nameIndexMap op rhs -def ofAstTypeArgs (argDecls : Array GenArgDecl) (argsVar : Ident) : GenM (Array Ident × Array (TSyntax ``doSeqItem)) := do +def ofAstTypeArgs (argDecls : Array CatOpArg) (argsVar : Ident) : GenM (Array Ident × Array (TSyntax ``doSeqItem)) := do let argCount := argDecls.size let ofAst ← ofAstIdentM q`Init.Type let args ← Array.ofFnM (n := argCount) fun ⟨i, _isLt⟩ => do @@ -972,10 +1295,9 @@ def ofAstTypeMatchRhs (cat : QualifiedIdent) (ann argsVar : Ident) (op : Default let argDecls := op.argDecls let (parsedArgs, stmts) ← ofAstTypeArgs argDecls argsVar let checkExpr ← ``(OfAstM.checkTypeArgCount $(quote argDecls.size) $(argsVar)) - `(do - let .up p ← $checkExpr:term - $stmts:doSeqItem* - pure <| $ctorIdent $ann $parsedArgs:term*) + `(do let .up p ← $checkExpr:term + $stmts:doSeqItem* + pure <| $ctorIdent $ann $parsedArgs:term*) def ofAstOpMatchRhs (cat : QualifiedIdent) (annI argsVar : Ident) (op : DefaultCtor) : GenM Term := do let some name := op.strataName @@ -994,27 +1316,24 @@ def ofAstOpMatchRhs (cat : QualifiedIdent) (annI argsVar : Ident) (op : DefaultC Creates a mapping from operation names (QualifiedIdent) to unique natural numbers. This is used to pattern match in the generated code. -/ -def createNameIndexMap (cat : QualifiedIdent) (ops : Array DefaultCtor) : GenM (Std.HashMap QualifiedIdent Nat × Ident × Command) := do +def createNameIndexMap (cat : QualifiedIdent) (ops : Array DefaultCtor) : + GenM (Std.HashMap QualifiedIdent Nat × Ident × Command) := do let nameIndexMap := ops.foldl (init := {}) fun map op => match op.strataName with | none => map -- Skip operators without a name | some name => map.insert name map.size -- Assign the next available index - let ofAstNameMap ← currScopedIdent <| (← getCategoryScopedName cat) ++ `ofAst.nameIndexMap + let ofAstNameMap ← + currScopedIdent <| (← getCategoryScopedName cat) ++ `ofAst.nameIndexMap let cmd ← `(def $ofAstNameMap : Std.HashMap Strata.QualifiedIdent Nat := Std.HashMap.ofList $(quote nameIndexMap.toList)) pure (nameIndexMap, ofAstNameMap, cmd) -def mkOfAstDef (cat : QualifiedIdent) (ofAst : Ident) (v : Name) (rhs : Term) : GenM Command := do +def mkOfAstDef (cat : QualifiedIdent) (ofAst : Ident) (v : Name) (rhs : Term) : + GenM Command := do let src := (←read).src let annType := localIdent "α" let catTerm ← getCategoryTerm cat annType `(partial def $ofAst {$annType : Type} [Inhabited $annType] [Repr $annType] ($(mkCanIdent src v) : $(categoryToAstTypeIdent cat annType)) : OfAstM $catTerm := $rhs) -def matchTypeParamOrType {Ann α} [Repr Ann] (a : ArgF Ann) (onTypeParam : Ann → α) (onType : TypeExprF Ann → OfAstM α) : OfAstM α := - match a with - | .cat (.atom ann q`Init.Type) => pure (onTypeParam ann) - | .type tp => onType tp - | _ => .throwExpected "Type parameter or type expression" a - def genOfAst (cat : QualifiedIdent) (ops : Array DefaultCtor) : GenM (Array Command × Command) := do let src := (←read).src let ofAst ← ofAstIdentM cat @@ -1026,7 +1345,11 @@ def genOfAst (cat : QualifiedIdent) (ops : Array DefaultCtor) : GenM (Array Comm let (annC, annI) ← genFreshIdentPair "ann" let (nameIndexMap, ofAstNameMap, cmd) ← createNameIndexMap cat ops let fvarCtorIdent ← getCategoryOpIdent cat `fvar - let cases : Array MatchAlt ← ops.mapM (ofAstExprMatch nameIndexMap cat annI (mkIdentFrom src argsVar)) + let cases : Array MatchAlt ← ops.filterMapM fun op => + if op.leanNameStr == "fvar" then + pure none + else + some <$> ofAstExprMatch nameIndexMap cat annI (mkIdentFrom src argsVar) op let rhs ← `(let vnf := ($(mkIdentFrom src v)).hnf let $(mkCanIdent src argsVar) := vnf.args.val @@ -1059,7 +1382,7 @@ def genOfAst (cat : QualifiedIdent) (ops : Array DefaultCtor) : GenM (Array Comm let exprCtorIdent ← getCategoryOpIdent cat `expr let typeOfAst ← ofAstIdentM q`Init.Type let rhs ← ``( - matchTypeParamOrType $(mkIdentFrom src v) $catCtorIdent (fun tp => $exprCtorIdent <$> $typeOfAst tp) + Strata.OfAstM.matchTypeParamOrType $(mkIdentFrom src v) $catCtorIdent (fun tp => $exprCtorIdent <$> $typeOfAst tp) ) pure (#[], ← mkOfAstDef cat ofAst v rhs) | _ => @@ -1080,7 +1403,8 @@ def genOfAst (cat : QualifiedIdent) (ops : Array DefaultCtor) : GenM (Array Comm abbrev InhabitedSet := Std.HashSet QualifiedIdent -def checkInhabited (cat : QualifiedIdent) (ops : Array DefaultCtor) : StateT InhabitedSet GenM Unit := do +def checkInhabited (cat : QualifiedIdent) (ops : Array DefaultCtor) : + StateT InhabitedSet GenM Unit := do if cat ∈ (←get) then return () let annType := localIdent "α" @@ -1097,7 +1421,9 @@ def checkInhabited (cat : QualifiedIdent) (ops : Array DefaultCtor) : StateT Inh continue let ctor : Term ← getCategoryOpIdent cat op.leanName let d := Lean.mkCIdent ``default - let e := Lean.Syntax.mkApp ctor (Array.replicate (op.argDecls.size + 1) d) + let argc := if op.includeAnn then 1 else 0 + let argc := argc + op.argDecls.size + let e := mkApp ctor (Array.replicate argc d) StateT.lift <| runCmd <| elabCommand =<< `(instance [Inhabited $annType] : Inhabited $catTerm where default := $e) modify (·.insert cat) @@ -1114,11 +1440,84 @@ partial def addInhabited (group : Array (QualifiedIdent × Array DefaultCtor)) ( else pure sm -def gen (categories : Array (QualifiedIdent × Array DefaultCtor)) : GenM Unit := do +/-- +Given a category and an unwrap tag, this returns a Lean expression +which given a value in the category with the wrap parameter, returns +the annotation value. +-/ +partial def annExpr (c : SyntaxCat) (wrap : Bool) (annType : Lean.Expr) : GenM Lean.Expr := do + let cat := c.name + if let some name := polymorphicBuiltinCategories[cat]? then + assert! c.args.size == 1 + let baseType ← leanCatTypeExpr annType c.args[0]! + let baseType := Lean.mkApp (.const name [0]) baseType + return mkApp2 (.const ``Ann.ann []) baseType annType + if let some name := declaredCategories[cat]? then + assert! c.args.size == 0 + assert! wrap + return mkApp2 (.const ``Ann.ann []) (.const name []) annType + + assert! c.args.size == 0 + match (←read).categoryNameMap[cat]? with + | some catName => + return Lean.mkApp (.const ((← getCurrNamespace) ++ catName ++ `ann) []) annType + | none => + return panic! s!"annExpr given {cat}" + +def annRecursor (c : DefaultCtor) : GenM Lean.Expr := do + let argc := c.argDecls.size + let (inner_off, ann) ← + if h : c.includeAnn then + pure (2, Lean.Expr.bvar argc) + else + have ne : c.argDecls.size > 0 := by + have p := c.includeAnnInvariant + grind + let d := c.argDecls[0] + let annFn ← annExpr d.cat (wrap := d.wrap) (.bvar (argc+1)) + pure (1, Lean.mkApp annFn (.bvar (argc - 1))) + let inner : Lean.Expr ← Fin.foldrM argc (init := ann) fun i e => do + let a := c.argDecls[i] + let argType ← leanCatTypeExpr (.bvar (inner_off + i)) a.cat (wrap := a.wrap) + return .lam (.mkSimple a.name) argType (binderInfo := .default) e + if c.includeAnn then + return .lam `ann (.bvar 1) (binderInfo := .default) inner + else + return inner + +def genAnnFunctions (cat : QualifiedIdent) (ctors : Array DefaultCtor) : GenM Unit := do + let relName ← getCategoryScopedName cat + + let catName := (← getCurrNamespace) ++ relName + let catType : Lean.Expr := mkConst catName + let defName := catName ++ `ann + let type : Lean.Expr := + .forallE `α (.sort 1) (binderInfo := .implicit) <| + .forallE `_ (.app catType (.bvar 0)) (binderInfo := .default) <| + .bvar 1 + let motive : Lean.Expr := .lam `_ (.app catType (.bvar 1)) (binderInfo := .default) (.bvar 2) + let term : Lean.Expr := mkApp3 (.const (catName ++ `casesOn) [1]) (.bvar 1) motive (.bvar 0) + let term : Lean.Expr ← ctors.foldlM (init := term) fun f c => + return .app f (← annRecursor c) + let value : Lean.Expr := + .lam `α (.sort 1) (binderInfo := .implicit) <| + .lam `a (.app catType (.bvar 0)) (binderInfo := .default) <| + term + liftCoreM <| addAndCompile <| .defnDecl { + name := defName + levelParams := [] + type := type + value := value + hints := .opaque + safety := .safe + all := [defName] + } + +def gen (categories : Array (Array (QualifiedIdent × Array DefaultCtor))) : GenM Unit := do let mut inhabitedCats : InhabitedSet := Std.HashSet.ofArray declaredCategories.keysArray - for allCtors in orderedSyncatGroups categories do + for allCtors in categories do let s ← withTraceNode `Strata.generator (fun _ => return m!"Declarations group: {allCtors.map (·.fst)}") do @@ -1135,14 +1534,15 @@ def gen (categories : Array (QualifiedIdent × Array DefaultCtor)) : GenM Unit : let cats := allCtors.map (·.fst) profileitM Lean.Exception s!"Generating inductives {cats}" (← getOptions) do let inductives ← allCtors.mapM fun (cat, ctors) => do - assert! q`Init.Num ≠ cat - assert! q`Init.Str ≠ cat mkInductive cat ctors runCmd <| elabCommands inductives let inhabitedCats2 ← profileitM Lean.Exception s!"Generating inhabited {cats}" (← getOptions) do addInhabited allCtors inhabitedCats let inhabitedCats := inhabitedCats2 + profileitM Lean.Exception s!"Generating ann functions {cats}" (← getOptions) do + allCtors.forM fun (cat, ctors) => do + genAnnFunctions cat ctors profileitM Lean.Exception s!"Generating toAstDefs {cats}" (← getOptions) do let toAstDefs ← allCtors.mapM fun (cat, ctors) => do genToAst cat ctors @@ -1156,23 +1556,13 @@ def gen (categories : Array (QualifiedIdent × Array DefaultCtor)) : GenM Unit : pure inhabitedCats inhabitedCats := s -def runGenM (src : Lean.Syntax) (pref : String) (catNames : Array QualifiedIdent) (exprHasEta : Bool) (m : GenM α) : CommandElabM α := do - let catNameCounts : Std.HashMap String Nat := - catNames.foldl (init := {}) fun m k => - m.alter k.name (fun v => some (v.getD 0 + 1)) - let categoryNameMap := catNames.foldl (init := {}) fun m i => - let name := - if catNameCounts.getD i.name 0 > 1 then - s!"{i.dialect}_{i.name}" - else if i.name ∈ reservedCats then - s!"{pref}{i.name}" - else - i.name - m.insert i name +def runGenM {α} (src : Lean.Syntax) + (categoryNameMap : Std.HashMap QualifiedIdent Lean.Name) + (m : GenM α) + : CommandElabM α := do let ctx : GenContext := { src := src - categoryNameMap := categoryNameMap - exprHasEta := exprHasEta + categoryNameMap } m ctx @@ -1184,30 +1574,75 @@ and back. -/ syntax (name := strataGenCmd) "#strata_gen" ident : command -- declare the syntax +/-- +Create a map from names of categories to their Lean name. +-/ +def mkCatNameMap (pref : String) (cm : CatIndexMap) (em : PartialNodeMapping cm.size) + : Std.HashMap QualifiedIdent Name := + -- Get number + let catNameCounts : Std.HashMap String Nat := + em.targetSize.fold (init := {}) fun tgtNode tgtlt m => + let i := cm.idents[em.sourceNode tgtNode] + m.alter i.name (fun v => some (v.getD 0 + 1)) + let init := .emptyWithCapacity em.targetSize + em.targetSize.fold (init := init) fun tgtNode tgtlt m => + let i := cm.idents[em.sourceNode tgtNode] + let name := + if catNameCounts.getD i.name 0 > 1 then + .mkSimple s!"{i.dialect}_{i.name}" + else if i.name ∈ reservedCatNames then + .mkSimple s!"{pref}{i.name}" + else + .mkSimple i.name + m.insert i name + @[command_elab strataGenCmd] -def genAstImpl : CommandElab := fun stx => +public def genAstImpl : CommandElab := fun stx => match stx with | `(#strata_gen $dialectStx) => do let .str .anonymous dialectName := dialectStx.getId | throwErrorAt dialectStx s!"Expected dialect name" let loader := dialectExt.getState (← getEnv) |>.loaded - let depDialectNames := generateDependentDialects (loader.dialects[·]?) dialectName + let depDialectNames := computeImportedDialects loader.dialects dialectName let usedDialects ← depDialectNames.mapM fun nm => match loader.dialects[nm]? with | some d => pure d | none => panic! s!"Missing dialect {nm}" let some d := loader.dialects[dialectName]? | throwErrorAt dialectStx "Missing dialect" - let (cm, errs) := mkCatOpMap usedDialects + let (cm, errs) := mkCatIndexMap usedDialects if errs.size > 0 then for e in errs do logError e return + let errors := checkDialect cm d + if errors.size > 0 then + for e in errs do + logError e + return + + let g := indexMapToGraph cm + -- Get all categories that are introduced or modified by d + let relevantCategories := newCategories cm d + -- Compute the closure of all categories needed for relevant categories. + let usedSet := OutGraph.preimage_closure g relevantCategories let exprHasEta := false -- FIXME - let cats := cm.onlyUsedCategories d exprHasEta - let catNames := cats.map (·.fst) - runGenM stx dialectName catNames exprHasEta (gen cats) + +-- let count : Nat := usedSet.foldl (b := 0) fun c e => if e then c + 1 else c + + + let em : Graph.PartialNodeMapping cm.size := .fromVector usedSet + let g := g.projection em.embedFn + + let mutual_indices : Array (Array (OutGraph.Node em.targetSize)) := OutGraph.tarjan g + let mutual_groups := mutual_indices.map fun a => a.map fun j => + ctors exprHasEta cm (em.sourceNode j.val) + + let categoryNameMap := mkCatNameMap dialectName cm em + runGenM stx categoryNameMap (gen mutual_groups) | _ => throwUnsupportedSyntax end Strata + +end diff --git a/Strata/DDM/Integration/Lean/HashCommands.lean b/Strata/DDM/Integration/Lean/HashCommands.lean index addac2f96..cee6b8a4e 100644 --- a/Strata/DDM/Integration/Lean/HashCommands.lean +++ b/Strata/DDM/Integration/Lean/HashCommands.lean @@ -5,17 +5,12 @@ -/ module -public import Lean.Parser.Types - public import Lean.Elab.Command - -public meta import Strata.DDM.Integration.Lean.ToExpr -import Strata.DDM.TaggedRegions - -public meta import Strata.DDM.Integration.Lean.Env +public import Lean.Parser.Types public meta import Strata.DDM.Elab - -meta import Strata.DDM.TaggedRegions +public meta import Strata.DDM.Integration.Lean.Env +public meta import Strata.DDM.Integration.Lean.ToExpr +public meta import Strata.DDM.TaggedRegions open Lean open Lean.Elab (throwUnsupportedSyntax) @@ -24,6 +19,7 @@ open Lean.Elab.Term (TermElab) open Lean.Parser (InputContext) open System (FilePath) +public meta section namespace Strata class HasInputContext (m : Type → Type _) [Functor m] where @@ -31,9 +27,7 @@ class HasInputContext (m : Type → Type _) [Functor m] where getFileName : m FilePath := (fun ctx => FilePath.mk ctx.fileName) <$> getInputContext ---export HasInputContext (getInputContext) - -meta instance : HasInputContext CommandElabM where +private instance : HasInputContext CommandElabM where getInputContext := do let ctx ← read pure { @@ -43,7 +37,7 @@ meta instance : HasInputContext CommandElabM where } getFileName := return (← read).fileName -meta instance : HasInputContext CoreM where +private instance : HasInputContext CoreM where getInputContext := do let ctx ← read pure { @@ -53,7 +47,7 @@ meta instance : HasInputContext CoreM where } getFileName := return (← read).fileName -private meta def mkScopedName {m} [Monad m] [MonadError m] [MonadEnv m] [MonadResolveName m] (name : Name) : m Name := do +private def mkScopedName {m} [Monad m] [MonadError m] [MonadEnv m] [MonadResolveName m] (name : Name) : m Name := do let scope ← getCurrNamespace let fullName := scope ++ name let env ← getEnv @@ -61,17 +55,10 @@ private meta def mkScopedName {m} [Monad m] [MonadError m] [MonadEnv m] [MonadRe throwError s!"Cannot define {name}: {fullName} already exists." return fullName -/-- -Prepend the current namespace to the Lean name and convert to an identifier. --/ -private def mkAbsIdent (name : Lean.Name) : Ident := - let nameStr := toString name - .mk (.ident .none nameStr.toSubstring name [.decl name []]) - /-- Add a definition to environment and compile it. -/ -meta def addDefn (name : Lean.Name) +private def addDefn (name : Lean.Name) (type : Lean.Expr) (value : Lean.Expr) (levelParams : List Name := []) @@ -91,7 +78,7 @@ meta def addDefn (name : Lean.Name) /-- Declare dialect and add to environment. -/ -public meta def declareDialect (d : Dialect) : CommandElabM Unit := do +def declareDialect (d : Dialect) : CommandElabM Unit := do -- Identifier for dialect let dialectName := Name.anonymous |>.str d.name let dialectAbsName ← mkScopedName dialectName @@ -120,7 +107,7 @@ public meta def declareDialect (d : Dialect) : CommandElabM Unit := do declare_tagged_region command strataDialectCommand "#dialect" "#end" @[command_elab strataDialectCommand] -public meta def strataDialectImpl: Lean.Elab.Command.CommandElab := fun (stx : Syntax) => do +def strataDialectImpl: CommandElab := fun (stx : Syntax) => do let .atom i v := stx[1] | throwError s!"Bad {stx[1]}" let .original _ p _ e := i @@ -138,7 +125,7 @@ public meta def strataDialectImpl: Lean.Elab.Command.CommandElab := fun (stx : S declare_tagged_region term strataProgram "#strata" "#end" @[term_elab strataProgram] -public meta def strataProgramImpl : TermElab := fun stx tp => do +meta def strataProgramImpl : TermElab := fun stx tp => do let .atom i v := stx[1] | throwError s!"Bad {stx[1]}" let .original _ p _ e := i @@ -168,7 +155,7 @@ public meta def strataProgramImpl : TermElab := fun stx tp => do syntax (name := loadDialectCommand) "#load_dialect" str : command -meta def resolveLeanRelPath {m} [Monad m] [HasInputContext m] [MonadError m] (path : FilePath) : m FilePath := do +def resolveLeanRelPath {m} [Monad m] [HasInputContext m] [MonadError m] (path : FilePath) : m FilePath := do if path.isAbsolute then pure path else @@ -178,7 +165,7 @@ meta def resolveLeanRelPath {m} [Monad m] [HasInputContext m] [MonadError m] (pa pure <| leanDir / path @[command_elab loadDialectCommand] -public meta def loadDialectImpl: CommandElab := fun (stx : Syntax) => do +def loadDialectImpl: CommandElab := fun (stx : Syntax) => do match stx with | `(command|#load_dialect $pathStx) => let dialectPath : FilePath := pathStx.getString @@ -199,3 +186,4 @@ public meta def loadDialectImpl: CommandElab := fun (stx : Syntax) => do throwUnsupportedSyntax end Strata +end diff --git a/Strata/DDM/Integration/Lean/OfAstM.lean b/Strata/DDM/Integration/Lean/OfAstM.lean index 58f85370d..8ee13a3ce 100644 --- a/Strata/DDM/Integration/Lean/OfAstM.lean +++ b/Strata/DDM/Integration/Lean/OfAstM.lean @@ -6,7 +6,7 @@ module public import Strata.DDM.AST public import Strata.DDM.HNF -import Strata.DDM.Util.Array +import all Strata.DDM.Util.Array public section namespace Strata @@ -44,13 +44,13 @@ instance : Monad OfAstM := inferInstanceAs (Monad (Except _)) /-- Thrown when an expression is provided but not expected. -/ -def throwUnknownExpr (tp : QualifiedIdent) (e : Expr) : OfAstM α := +def throwUnknownExpr {α} (tp : QualifiedIdent) (e : Expr) : OfAstM α := Except.error s!"Unknown expr {repr e} when parsing as {tp}." /-- Thrown when an expression is provided but not expected. -/ -def throwUnknownType (e : TypeExpr) : OfAstM α := +def throwUnknownType {α} (e : TypeExpr) : OfAstM α := Except.error s!"Unknown type {repr e}." /-- @@ -133,26 +133,31 @@ def ofOperationM {α β} [Repr α] [SizeOf α] | .op a1 => act a1 (by decreasing_tactic) | a => .throwExpected "operation" a -def ofIdentM {α} [Repr α] : ArgF α → OfAstM (Ann String α) -| .ident ann val => pure { ann := ann, val := val } +def ofBytesM {α} [Repr α] : ArgF α → OfAstM ByteArray +| .bytes _ val => pure val +| a => .throwExpected "byte array" a + +@[inline] +def ofDecimalM {α} [Repr α] : ArgF α → OfAstM Decimal +| .decimal _ val => pure val +| a => .throwExpected "scientific literal" a + +@[inline] +def ofIdentM {α} [Repr α] : ArgF α → OfAstM String +| .ident _ val => pure val | a => .throwExpected "identifier" a -def ofNumM {α} [Repr α] : ArgF α → OfAstM (Ann Nat α) -| .num ann val => pure { ann := ann, val := val } +@[inline] +def ofNumM {α} [Repr α] : ArgF α → OfAstM Nat +| .num _ val => pure val | a => .throwExpected "numeric literal" a -def ofDecimalM {α} [Repr α] : ArgF α → OfAstM (Ann Decimal α) -| .decimal ann val => pure { ann := ann, val := val } -| a => .throwExpected "scientific literal" a -def ofStrlitM {α} [Repr α] : ArgF α → OfAstM (Ann String α) -| .strlit ann val => pure { ann := ann, val := val } +@[inline] +def ofStrlitM {α} [Repr α] : ArgF α → OfAstM String +| .strlit _ val => pure val | a => .throwExpected "string literal" a -def ofBytesM {α} [Repr α] : ArgF α → OfAstM (Ann ByteArray α) -| .bytes ann val => pure { ann := ann, val := val } -| a => .throwExpected "byte array" a - def ofOptionM {α β} [Repr α] [SizeOf α] (arg : ArgF α) (act : ∀(e : ArgF α), sizeOf e < sizeOf arg → OfAstM β) @@ -227,5 +232,16 @@ def exprEtaArg{Ann α T} [Repr Ann] [HasEta α T] {e : Expr} {n : Nat} (as : Siz let i := n - 1 - lvl return HasEta.bvar i +def matchTypeParamOrType {Ann α} [Repr Ann] (a : ArgF Ann) (onTypeParam : Ann → α) (onType : TypeExprF Ann → OfAstM α) : OfAstM α := + match a with + | .type tp => onType tp + | .cat c => + if c.name = q`Init.Type then + pure (onTypeParam c.ann) + else + .throwExpected "Type parameter or type expression" a + | _ => + .throwExpected "Type parameter or type expression" a + end Strata.OfAstM end diff --git a/Strata/DDM/Integration/Lean/ToExpr.lean b/Strata/DDM/Integration/Lean/ToExpr.lean index 9a1d96f4e..3bf2df76c 100644 --- a/Strata/DDM/Integration/Lean/ToExpr.lean +++ b/Strata/DDM/Integration/Lean/ToExpr.lean @@ -306,9 +306,6 @@ instance SynCatDecl.instToExpr : ToExpr SynCatDecl where namespace DebruijnIndex -private protected def ofNat {n : Nat} [NeZero n] (a : Nat) : DebruijnIndex n := - ⟨a % n, Nat.mod_lt _ (Nat.pos_of_neZero n)⟩ - instance {n} : ToExpr (DebruijnIndex n) where toTypeExpr := private .app (mkConst ``DebruijnIndex) (toExpr n) toExpr a := private diff --git a/Strata/DDM/Ion.lean b/Strata/DDM/Ion.lean index b691d2e29..e0c69d660 100644 --- a/Strata/DDM/Ion.lean +++ b/Strata/DDM/Ion.lean @@ -17,6 +17,16 @@ open Lean.Elab.Command open Ion +namespace Array + +def mapM_off {α β m} [Monad m] (as : Array α) (f : α → m β) + (start : Nat := 0) (stop := as.size) + (init : Array β := Array.mkEmpty ((min as.size stop) - start)) : m (Array β) := + as.foldlM (init := init) (start := start) (stop := stop) + fun r e => r.push <$> f e + +end Array + public section namespace Ion.Ion diff --git a/Strata/DDM/TaggedRegions.lean b/Strata/DDM/TaggedRegions.lean index 59ac1c5a7..8f91887e3 100644 --- a/Strata/DDM/TaggedRegions.lean +++ b/Strata/DDM/TaggedRegions.lean @@ -7,7 +7,7 @@ module import Lean.PrettyPrinter.Formatter import Lean.PrettyPrinter.Parenthesizer -import Strata.DDM.Util.String +import all Strata.DDM.Util.String public meta import Lean.Elab.Syntax diff --git a/Strata/DDM/Util/Array.lean b/Strata/DDM/Util/Array.lean index 2be7c3743..5c209cfb4 100644 --- a/Strata/DDM/Util/Array.lean +++ b/Strata/DDM/Util/Array.lean @@ -4,14 +4,14 @@ SPDX-License-Identifier: Apache-2.0 OR MIT -/ module +import all Strata.DDM.Util.List -import Strata.DDM.Util.List +set_option autoImplicit false -public section namespace Array @[simp] -theorem anyM_empty {α} [Monad m] (f : α → m Bool) (start : Nat := 0) (stop : Nat := 0) +theorem anyM_empty {α m} [Monad m] (f : α → m Bool) (start : Nat := 0) (stop : Nat := 0) : Array.anyM f #[] start stop = @pure m _ _ false := by unfold Array.anyM split @@ -23,18 +23,6 @@ theorem anyM_empty {α} [Monad m] (f : α → m Bool) (start : Nat := 0) (stop : unfold anyM.loop simp -def map_off {α β} (as : Array α) (f : α → β) - (start : Nat := 0) (stop : Nat := as.size) - (init : Array β := Array.mkEmpty ((min as.size stop) - start)) : Array β := - as.foldl (init := init) (start := start) (stop := stop) - fun r e => r.push (f e) - -def mapM_off {α β m} [Monad m] (as : Array α) (f : α → m β) - (start : Nat := 0) (stop := as.size) - (init : Array β := Array.mkEmpty ((min as.size stop) - start)) : m (Array β) := - as.foldlM (init := init) (start := start) (stop := stop) - fun r e => r.push <$> f e - private theorem extract_loop_succ_upper {α} (as b : Array α) (i j : Nat) (h : i + j < as.size) : Array.extract.loop as (i + 1) j b = (Array.extract.loop as i j b).push (as[i + j]'h) := by @@ -52,7 +40,7 @@ private theorem extract_loop_succ_upper {α} (as b : Array α) (i j : Nat) (h : have p : j + (i + 1) = j + 1 + i := by omega simp [g, hyp _ _ h, p] -private theorem extract_succ {α} (as : Array α) {i : Nat} (g : i ≤ j) (h : j < as.size) : as.extract i (j + 1) = (as.extract i j).push (as[j]'h) := by +private theorem extract_succ {α} (as : Array α) {i j : Nat} (g : i ≤ j) (h : j < as.size) : as.extract i (j + 1) = (as.extract i j).push (as[j]'h) := by have j1_le : (j + 1) ≤ as.size := by omega have j_le : j ≤ as.size := by omega have p : j + 1 - i = j - i + 1 := by omega @@ -62,7 +50,7 @@ private theorem extract_succ {α} (as : Array α) {i : Nat} (g : i ≤ j) (h : j private theorem sizeOf_toList {α} [SizeOf α] (as : Array α) : sizeOf as = 1 + sizeOf as.toList := rfl -theorem sizeOf_min [SizeOf α] (as : Array α) : sizeOf as ≥ 2 := by +theorem sizeOf_min {α} [SizeOf α] (as : Array α) : sizeOf as ≥ 2 := by have p := sizeOf_toList as have q := List.sizeOf_pos as.toList omega @@ -74,7 +62,7 @@ theorem sizeOf_push {α} [SizeOf α] (as : Array α) (a : α) : omega @[simp] -theorem sizeOf_set [SizeOf α] (a : Array α) (i : Nat) (v : α) (hi : i < a.size) : sizeOf (a.set i v) = sizeOf a - sizeOf a[i] + sizeOf v := by +theorem sizeOf_set {α} [SizeOf α] (a : Array α) (i : Nat) (v : α) (hi : i < a.size) : sizeOf (a.set i v) = sizeOf a - sizeOf a[i] + sizeOf v := by match a with | .mk l => unfold Array.set @@ -84,7 +72,7 @@ theorem sizeOf_set [SizeOf α] (a : Array α) (i : Nat) (v : α) (hi : i < a.si omega @[simp] -theorem sizeOf_swap [h : SizeOf α] (a : Array α) (i : Nat) (j : Nat) (hi : i < a.size) (hj : j < a.size) : sizeOf (a.swap i j) = sizeOf a := by +theorem sizeOf_swap {α} [h : SizeOf α] (a : Array α) (i : Nat) (j : Nat) (hi : i < a.size) (hj : j < a.size) : sizeOf (a.swap i j) = sizeOf a := by unfold Array.swap have h : sizeOf a[i] < sizeOf a := sizeOf_getElem _ _ _ simp [Array.getElem_set] @@ -109,7 +97,7 @@ theorem sizeOf_reverse {α} [SizeOf α] (a : Array α) : sizeOf a.reverse = size case isFalse p => simp [sizeOf_reverse_loop] -theorem sizeOf_lt_of_mem_strict [SizeOf α] {as : Array α} (h : a ∈ as) : sizeOf a + 3 ≤ sizeOf as := by +theorem sizeOf_lt_of_mem_strict {α} [SizeOf α] {a} {as : Array α} (h : a ∈ as) : sizeOf a + 3 ≤ sizeOf as := by cases as with | _ as => simp +arith [List.sizeOf_lt_of_mem_strict h.val] @@ -122,5 +110,48 @@ theorem of_mem_pop {α} {a : α} {as : Array α} : a ∈ as.pop → a ∈ as := simp [Array.mem_iff_getElem] grind +theorem mem_pop_ne {α} {a : α} {as : Array α} (ne : as.size > 0) : + a ∈ as ↔ a ∈ as.pop ∨ a = as.back := by + have as_ne : as ≠ #[] := Array.ne_empty_of_size_pos ne + have ne' : as.toList ≠ [] := + as_ne ∘ Array.toList_eq_nil_iff.mp + simp only [ + ← Array.mem_toList_iff, + List.mem_ne_as_dropLast _ ne', + Array.getLast_toList, + Array.toList_pop + ] + +theorem size_filter_pos {α} {p : α → Bool} {as : Array α} {i : Nat} + {h : i < as.size} (witness : p as[i] = true) : (Array.filter p as).size > 0 := by + have as_eq : as.filter p = (as.set i as[i] h).filter p := by + simp [Array.set_getElem_self] + rw [as_eq] + simp only [Array.set_eq_push_extract_append_extract] + simp only [Array.filter_append, Array.size_append, Array.filter_push] + simp [witness] + omega + +theorem size_filter_set {α} (p : α → Bool) (as : Array α) (i : Nat) (v : α) + (h : i < as.size := by get_elem_tactic) : (Array.filter p (as.set i v)).size = + (Array.filter p as).size + + (if p v then 1 else 0) + - (if p as[i] = true then 1 else 0) := by + have as_eq : as.filter p = (as.set i as[i] h).filter p := by + simp [Array.set_getElem_self] + rw [as_eq] + simp only [Array.set_eq_push_extract_append_extract] + simp only [Array.filter_append, Array.size_append, Array.filter_push] + if newp : p v = true then + if oldp : p as[i] = true then + simp [newp, oldp] + else + simp [newp, oldp] + omega + else + if oldp : p as[i] = true then + simp [newp, oldp] + else + simp [newp, oldp] + end Array -end diff --git a/Strata/DDM/Util/ByteArray.lean b/Strata/DDM/Util/ByteArray.lean index 61109a5e0..4c63faf7c 100644 --- a/Strata/DDM/Util/ByteArray.lean +++ b/Strata/DDM/Util/ByteArray.lean @@ -11,14 +11,13 @@ Functions for ByteArray that could potentially be upstreamed to Lean. import Std.Data.HashMap public import Lean.ToExpr -public section namespace ByteArray -private def back! (a : ByteArray) : UInt8 := a.get! (a.size - 1) +def back! (a : ByteArray) : UInt8 := a.get! (a.size - 1) -private def back? (a : ByteArray) : Option UInt8 := a[a.size - 1]? +def back? (a : ByteArray) : Option UInt8 := a[a.size - 1]? -private def pop (a : ByteArray) : ByteArray := a.extract 0 (a.size - 1) +def pop (a : ByteArray) : ByteArray := a.extract 0 (a.size - 1) @[inline] def foldr {β} (f : UInt8 → β → β) (init : β) (as : ByteArray) (start := as.size) (stop := 0) : β := @@ -30,25 +29,12 @@ def foldr {β} (f : UInt8 → β → β) (init : β) (as : ByteArray) (start := aux (i-1) (by omega) (f as[i-1] b) aux (min start as.size) (Nat.min_le_right _ _) init -private def byteToHex (b : UInt8) : String := - let cl : String := .ofList (Nat.toDigits 16 b.toNat) - if cl.length < 2 then "0" ++ cl else cl - -def asHex (a : ByteArray) : String := - a.foldl (init := "") fun s b => s ++ byteToHex b - def startsWith (a pre : ByteArray) := if isLt : a.size < pre.size then false else pre.size.all fun i _ => a[i] = pre[i] -private protected def reprPrec (a : ByteArray) (p : Nat) := - Repr.addAppParen ("ByteArray.mk " ++ reprArg a.data) p - -instance : Repr ByteArray where - reprPrec := private ByteArray.reprPrec - end ByteArray #guard (ByteArray.empty |>.back!) = default @@ -57,6 +43,7 @@ end ByteArray #guard (ByteArray.empty |>.pop) = .empty #guard let a := ByteArray.empty |>.push 0 |>.push 1; (a |>.push 2 |>.pop) = a +public section namespace Strata.ByteArray def ofNatArray (a : Array Nat) : ByteArray := .mk (a.map UInt8.ofNat) @@ -66,6 +53,19 @@ instance : Lean.ToExpr ByteArray where toTypeExpr := private mkConst ``ByteArray toExpr a := private mkApp (mkConst ``ByteArray.ofNatArray) <| toExpr <| a.data.map (·.toNat) +private def byteToHex (b : UInt8) : String := + let cl : String := .ofList (Nat.toDigits 16 b.toNat) + if cl.length < 2 then "0" ++ cl else cl + +def asHex (a : ByteArray) : String := + a.foldl (init := "") fun s b => s ++ byteToHex b + +protected def reprPrec (a : ByteArray) (p : Nat) := + Repr.addAppParen ("ByteArray.mk " ++ reprArg a.data) p + +instance : Repr ByteArray where + reprPrec := ByteArray.reprPrec + def escapedBytes : Std.HashMap UInt8 Char := Std.HashMap.ofList [ (9, 't'), (10, 'n'), @@ -190,3 +190,4 @@ def unescapeBytes (s : String) : Except (String.ValidPos s × String.ValidPos s | .ok (a, _) => .ok a end Strata.ByteArray +end diff --git a/Strata/DDM/Util/Decimal.lean b/Strata/DDM/Util/Decimal.lean index 6e1ff0f49..f3b1ecfb0 100644 --- a/Strata/DDM/Util/Decimal.lean +++ b/Strata/DDM/Util/Decimal.lean @@ -5,11 +5,11 @@ -/ module -import Lean.ToExpr -import Strata.DDM.Util.Lean public import Lean.ToExpr -private def String.replicate (n : Nat) (c : Char) := n.repeat (a := "") (·.push c) +import Lean.ToExpr +import Strata.DDM.Util.Lean +import all Strata.DDM.Util.String public section namespace Strata diff --git a/Strata/DDM/Util/Fin.lean b/Strata/DDM/Util/Fin.lean index 1876e7675..16a5b5e86 100644 --- a/Strata/DDM/Util/Fin.lean +++ b/Strata/DDM/Util/Fin.lean @@ -6,9 +6,10 @@ module /- -Extra declarations in Fin namespace +Extra declarations in Fin namespace. + +These are private so we do not extend Lean's namespaces. -/ -public section namespace Fin instance {n} : Min (Fin n) where @@ -34,4 +35,3 @@ end Range def range (n : Nat) : Range n := .mk end Fin -end diff --git a/Strata/DDM/Util/Format.lean b/Strata/DDM/Util/Format.lean index 1811306b8..1cfaaaf3d 100644 --- a/Strata/DDM/Util/Format.lean +++ b/Strata/DDM/Util/Format.lean @@ -5,7 +5,7 @@ -/ module -import Strata.DDM.Util.String +import all Strata.DDM.Util.String namespace Std.Format diff --git a/Strata/DDM/Util/Graph/OutGraph.lean b/Strata/DDM/Util/Graph/OutGraph.lean new file mode 100644 index 000000000..cbf162718 --- /dev/null +++ b/Strata/DDM/Util/Graph/OutGraph.lean @@ -0,0 +1,39 @@ +module + +import all Strata.DDM.Util.Vector + +public section +namespace Strata + +structure OutGraph (nodeCount : Nat) where + /-- For each edge `s -> t` in the graph, we have `s ∈ edges[t]` -/ + edges : Vector (Array (Fin nodeCount)) nodeCount + deriving Inhabited, Repr + +namespace OutGraph + +abbrev Node n := Fin n + +protected def empty (n : Nat) : OutGraph n where + edges := .replicate n ∅ + +protected def addEdge {n} (g : OutGraph n) (f t : Node n) : OutGraph n := + { edges := Vector.modify g.edges ⟨t, by omega⟩ (·.push ⟨f, by omega⟩) + } + +protected def addEdge! {n} (g : OutGraph n) (f t : Nat) : OutGraph n := + if fp : f ≥ n then + @panic _ ⟨g⟩ s!"Invalid from edge {f}" + else if tp : t ≥ n then + @panic _ ⟨g⟩ s!"Invalid to edge {t}" + else + g.addEdge ⟨f, Nat.lt_of_not_le fp⟩ ⟨t, Nat.lt_of_not_le tp⟩ + +protected def ofEdges! (n : Nat) (edges : List (Nat × Nat)) : OutGraph n := + edges.foldl (fun g (f, t) => g.addEdge! f t) (.empty n) + +def nodesOut {n} (g : OutGraph n) (node : Node n) : Array (Node n) := + g.edges[node] + +end Strata.OutGraph +end diff --git a/Strata/DDM/Util/Graph/Preimage.lean b/Strata/DDM/Util/Graph/Preimage.lean new file mode 100644 index 000000000..9d77d5c60 --- /dev/null +++ b/Strata/DDM/Util/Graph/Preimage.lean @@ -0,0 +1,111 @@ +module +public import Strata.DDM.Util.Graph.OutGraph +public import Std.Data.HashSet + +import all Strata.DDM.Util.Array + +public section +namespace Strata.OutGraph + +private structure WorkSet (n : Nat) where + set : Vector Bool n + pending : Array (Fin n) + inv : ∀idx, idx ∈ pending → set[idx.val] = true + +namespace WorkSet + +private def remainingCount {n} (s : WorkSet n) : Nat := + (s.set.toArray.filter (!·)).size + s.pending.size + +private def empty {n} : WorkSet n := + { set := .replicate _ false + pending := #[] + inv := fun idx idxp => by simp only [Array.mem_empty_iff] at idxp + } + +private def addIdx {n} (s : WorkSet n) (idx : Fin n) : WorkSet n := + if p : s.set[idx] = true then + s + else + let { set, pending, inv } := s + { set := set.set idx true + pending := pending.push idx + inv := by grind + } + +private theorem remainingCount_addIdx {n} (s : WorkSet n) (idx : Fin n) : + (s.addIdx idx).remainingCount = s.remainingCount := by + simp only [WorkSet.addIdx] + if h : s.set[idx] = true then + simp [h] + else + have eq_false : s.set[idx.val] = false := + iff_of_eq (Bool.not_eq_true (s.set[idx.val])) |>.mp h + simp [-Vector.size_toArray, -Array.size_set, + remainingCount, Array.size_filter_set, eq_false] + have false_in : false ∈ s.set := ⟨Array.mem_of_getElem eq_false⟩ + have size_pos : (Array.filter (fun x => !x) s.set.toArray).size > 0 := by + simp [-Vector.size_toArray, Array.size_filter_pos_iff, false_in] + omega + + +@[inline] +private def pop {n} (s : WorkSet n) (p : s.pending.size > 0) : WorkSet n := + let { set, pending, inv } := s + { set + pending := pending.pop + inv := by + intro idx mem + apply inv + simp [Array.mem_pop_ne p, mem] + } + +private theorem remaining_count_pop {m} (s : WorkSet m) (p : s.pending.size > 0) : + (s.pop p).remainingCount + 1 = s.remainingCount := by + simp_all [pop, remainingCount] + grind + +end WorkSet + +private def preimage.aux {n} (g : OutGraph n) (s : WorkSet n) : Vector Bool n := + if p : s.pending.size > 0 then + let idx := s.pending.back + let s := s.pop p + let ops := g.edges[idx] + let s := ops.foldl (init := s) WorkSet.addIdx + preimage.aux g s + else + s.set +termination_by s.remainingCount +decreasing_by + rename_i u _ + have foldl_eq : ∀{n} (z : WorkSet n) (nodes : Array (Fin n)), + (Array.foldl WorkSet.addIdx z nodes).remainingCount = z.remainingCount := by + intro n z nodes + apply Array.foldl_induction (motive := fun _ (s : WorkSet n) => s.remainingCount = z.remainingCount) + case h0 => + simp + case hf => + intro sz s sp + simp [WorkSet.remainingCount_addIdx] + exact sp + simp [foldl_eq] + have inv := WorkSet.remaining_count_pop u p + omega + +/-- +`g.preimage_closure s` returns the nodes that have paths to some node in `s`. + +Specifically, let `v = g.preimage_closure s`, then `v[i] = true` iff there is +a path from `i` to any node in `s` via the edges in `g`. +-/ +def preimage_closure {n} + (g : OutGraph n) + (elts : Std.HashSet (Fin n)) + : Vector Bool n := + let s := .empty + let s := elts.fold (init := s) fun s idx => s.addIdx idx + preimage.aux g s + +end Strata.OutGraph +end diff --git a/Strata/DDM/Util/Graph/Tarjan.lean b/Strata/DDM/Util/Graph/Tarjan.lean index 7ce3f12b7..ddf94c207 100644 --- a/Strata/DDM/Util/Graph/Tarjan.lean +++ b/Strata/DDM/Util/Graph/Tarjan.lean @@ -4,40 +4,15 @@ SPDX-License-Identifier: Apache-2.0 OR MIT -/ module -import Strata.DDM.Util.Fin -import Strata.DDM.Util.Vector +public import Strata.DDM.Util.Graph.OutGraph -public section -namespace Strata - -structure OutGraph (nodeCount : Nat) where - edges : Vector (Array (Fin nodeCount)) nodeCount - deriving Inhabited, Repr - -namespace OutGraph - -abbrev Node n := Fin n - -protected def empty (n : Nat) : OutGraph n where - edges := .replicate n ∅ +import all Strata.DDM.Util.Fin +import all Strata.DDM.Util.Vector -protected def addEdge (g : OutGraph n) (f t : Node n) : OutGraph n := - { edges := g.edges.modify ⟨t, by omega⟩ (·.push ⟨f, by omega⟩) - } - -protected def addEdge! (g : OutGraph n) (f t : Nat) : OutGraph n := - if fp : f ≥ n then - @panic _ ⟨g⟩ s!"Invalid from edge {f}" - else if tp : t ≥ n then - @panic _ ⟨g⟩ s!"Invalid to edge {t}" - else - g.addEdge ⟨f, Nat.lt_of_not_le fp⟩ ⟨t, Nat.lt_of_not_le tp⟩ +meta import Strata.DDM.Util.Graph.OutGraph -protected def ofEdges! (n : Nat) (edges : List (Nat × Nat)) : OutGraph n := - edges.foldl (fun g (f, t) => g.addEdge! f t) (.empty n) - -def nodesOut (g : OutGraph n) (node : Node n) : Array (Node n) := - g.edges[node] +public section +namespace Strata.OutGraph private structure TarjanState (n : Nat) where index : Fin (n+1) := 0 @@ -48,10 +23,10 @@ private structure TarjanState (n : Nat) where components : Array (Array (Fin n)) := #[] deriving Inhabited -private def TarjanState.mergeLowlink (s : TarjanState n) (v : Fin n) (w : Fin n): TarjanState n := - { s with lowlinks := s.lowlinks.modify v (min s.lowlinks[w]) } +private def TarjanState.mergeLowlink {n} (s : TarjanState n) (v : Fin n) (w : Fin n): TarjanState n := + { s with lowlinks := Vector.modify s.lowlinks v (min s.lowlinks[w]) } -private def popTo (v : Fin n) (s : TarjanState n) (comp : Array (Fin n)) : TarjanState n := +private def popTo {n} (v : Fin n) (s : TarjanState n) (comp : Array (Fin n)) : TarjanState n := if p : s.stk.size > 0 then let w := s.stk[s.stk.size - 1] let s := { s with stk := s.stk.pop, onStack := s.onStack.set w false } @@ -63,7 +38,7 @@ private def popTo (v : Fin n) (s : TarjanState n) (comp : Array (Fin n)) : Tarja else panic "Unexpected empty stack" -private partial def strongconnect (g : OutGraph n) (v : Node n) (s : TarjanState n) : TarjanState n := +private partial def strongconnect {n} (g : OutGraph n) (v : Node n) (s : TarjanState n) : TarjanState n := -- Set the depth index for v to the smallest unused index let s := { s with index := s.index + 1, diff --git a/Strata/DDM/Util/Ion.lean b/Strata/DDM/Util/Ion.lean index 4f300962f..30dd1c8e2 100644 --- a/Strata/DDM/Util/Ion.lean +++ b/Strata/DDM/Util/Ion.lean @@ -10,13 +10,19 @@ public import Strata.DDM.Util.Ion.Deserialize public import Strata.DDM.Util.Ion.Serialize public import Strata.DDM.Util.Ion.SymbolTable -import Strata.DDM.Util.Fin +import all Strata.DDM.Util.ByteArray +import all Strata.DDM.Util.Fin import Strata.DDM.Util.Ion.Deserialize import Strata.DDM.Util.Ion.JSON public section namespace Ion +/-- +Returns true if this starts with the Ion binary version marker. +-/ +def isIonFile (bytes : ByteArray) : Bool := bytes.startsWith binaryVersionMarker + structure Position where indices : Array Nat := #[] deriving Repr @@ -145,7 +151,7 @@ def internAndSerialize (values : List (Ion String)) (symbols : SymbolTable := .s /-- Write a list of Ion values to file. -/ -def writeBinaryFile (path : System.FilePath) (values : List (Ion String)) (symbols : SymbolTable := system): IO Unit := do +def writeBinaryFile (path : System.FilePath) (values : List (Ion String)) (symbols : SymbolTable := .system): IO Unit := do IO.FS.writeBinFile path (internAndSerialize values symbols) end Ion diff --git a/Strata/DDM/Util/Ion/AST.lean b/Strata/DDM/Util/Ion/AST.lean index 79275a425..7d3f94bd9 100644 --- a/Strata/DDM/Util/Ion/AST.lean +++ b/Strata/DDM/Util/Ion/AST.lean @@ -13,6 +13,7 @@ public section namespace Ion export Strata (Decimal) +open Strata.ByteArray inductive CoreType | null diff --git a/Strata/DDM/Util/Ion/Serialize.lean b/Strata/DDM/Util/Ion/Serialize.lean index da550d6ba..09211ba7c 100644 --- a/Strata/DDM/Util/Ion/Serialize.lean +++ b/Strata/DDM/Util/Ion/Serialize.lean @@ -6,7 +6,7 @@ module public import Strata.DDM.Util.Ion.AST -import Strata.DDM.Util.ByteArray +import all Strata.DDM.Util.ByteArray namespace Strata.ByteArray diff --git a/Strata/DDM/Util/List.lean b/Strata/DDM/Util/List.lean index fc72f1194..c8143d850 100644 --- a/Strata/DDM/Util/List.lean +++ b/Strata/DDM/Util/List.lean @@ -4,8 +4,8 @@ SPDX-License-Identifier: Apache-2.0 OR MIT -/ module +set_option autoImplicit false -public section namespace List theorem sizeOf_pos {α} [SizeOf α] (l : List α) : sizeOf l > 0 := by @@ -34,7 +34,7 @@ theorem sizeOf_lt_of_mem_strict {α} [inst : SizeOf α] {as : List α} {a} (h : | tail _ _ ih => exact Nat.lt_trans ih (by simp +arith) @[simp] -theorem sizeOf_set [h : SizeOf α] (as : List α) (i : Nat) (v : α) : +theorem sizeOf_set {α} [h : SizeOf α] (as : List α) (i : Nat) (v : α) : sizeOf (as.set i v) = if p : i < as.length then sizeOf as + sizeOf v - sizeOf as[i] @@ -57,5 +57,22 @@ theorem sizeOf_set [h : SizeOf α] (as : List α) (i : Nat) (v : α) : case h_3 => simp +/-- +Rewrites non-emptty list membership to drop last and get last. +-/ +theorem mem_ne_as_dropLast {α} (a : α) {as : List α} (ne : as ≠ []) : + a ∈ as ↔ a ∈ as.dropLast ∨ a = as.getLast ne := + match as with + | [] => by + simp at ne + | [b] => by + simp + | b0 :: b1 :: bs => by + if h : a = b0 then + simp [h] + else + have ne' : b1 :: bs ≠ [] := by simp + have p := List.mem_ne_as_dropLast a ne' + grind + end List -end diff --git a/Strata/DDM/Util/OrderedSet.lean b/Strata/DDM/Util/OrderedSet.lean new file mode 100644 index 000000000..e765139f5 --- /dev/null +++ b/Strata/DDM/Util/OrderedSet.lean @@ -0,0 +1,49 @@ +module +public import Std.Data.HashSet + +public section +namespace Strata.DDM + +/-- +This is a hashset backed by an array so that we can fold +over the elements in the set. +-/ +structure OrderedSet (α : Type _) [BEq α] [Hashable α] where + private mk :: + private set : Std.HashSet α := {} + private values : Array α := #[] + +namespace OrderedSet + +/-- empty set. -/ +@[inline] +def empty {α} [BEq α] [Hashable α] : OrderedSet α := .mk {} #[] + +/-- +Return +-/ +def toArray {α} [BEq α] [Hashable α] (s : OrderedSet α) : Array α := s.values + +/-- Add an element to the set -/ +def insert {α} [BEq α] [Hashable α] (s : OrderedSet α) (a : α) : OrderedSet α := + if a ∈ s.set then + s + else + { set := s.set.insert a, values := s.values.push a } + +/-- +Add all reachable dialects +-/ +partial def addAllPostorder {α} [BEq α] [Hashable α] (pre : α → Array α) (s : OrderedSet α) (a : α) : OrderedSet α := + if a ∈ s.set then + s + else + let as := pre a + let s := { s with set := s.set.insert a } + let s := as.foldl (init := s) (addAllPostorder pre) + { s with values := s.values.push a } + +end OrderedSet + +end Strata.DDM +end diff --git a/Strata/DDM/Util/String.lean b/Strata/DDM/Util/String.lean index 7b4257581..aa7b1c30f 100644 --- a/Strata/DDM/Util/String.lean +++ b/Strata/DDM/Util/String.lean @@ -7,43 +7,12 @@ module import all Init.Data.String.Defs /- -This file contains auxillary definitions for String that could be -potentially useful to add. --/ +This file contains auxillary definitions for String. -public section -namespace Strata - -/-- -Return true if this is a non-printable 8-bit character +If they are general purpose, we keep them as private symbols +that could be imported via import all. Otherwise they are +added to Strata. -/ -private def useXHex ( c : Char) : Bool := - c < '\x20' ∨ '\x7f' ≤ c ∧ (c < '\xa1' ∨ c == '\xad') - -private def escapeStringLitAux (acc : String) (c : Char) : String := - if c == '"' then - acc ++ "\\\"" - else if c == '\\' then - acc ++ "\\\\" - else if c == '\n' then - acc ++ "\\n" - else if c == '\r' then - acc ++ "\\r" - else if c == '\t' then - acc ++ "\\t" - else if useXHex c then - let i := c.toNat - let digits := Nat.toDigits 16 i - if i < 16 then - s!"{acc}\\x0{digits[0]!}" - else - assert! digits.length = 2 - s!"{acc}\\x{digits[0]!}{digits[1]!}" - else - acc.push c - -def escapeStringLit (s : String) : String := - s.foldl escapeStringLitAux "\"" ++ "\"" namespace String @@ -51,11 +20,7 @@ namespace String theorem isEmpty_eq (s : _root_.String) : s.isEmpty = (s == "") := by simp only [String.isEmpty, BEq.beq, String.utf8ByteSize_eq_zero_iff] -end String - -end Strata - -namespace String +def replicate (n : Nat) (c : Char) := n.repeat (a := "") (·.push c) /-- Indicates s has a substring at the given index. @@ -119,4 +84,40 @@ info: [""] #eval "".splitLines end String + +public section +namespace Strata + +/-- +Return true if this is a non-printable 8-bit character +-/ +private def useXHex ( c : Char) : Bool := + c < '\x20' ∨ '\x7f' ≤ c ∧ (c < '\xa1' ∨ c == '\xad') + +private def escapeStringLitAux (acc : String) (c : Char) : String := + if c == '"' then + acc ++ "\\\"" + else if c == '\\' then + acc ++ "\\\\" + else if c == '\n' then + acc ++ "\\n" + else if c == '\r' then + acc ++ "\\r" + else if c == '\t' then + acc ++ "\\t" + else if useXHex c then + let i := c.toNat + let digits := Nat.toDigits 16 i + if i < 16 then + s!"{acc}\\x0{digits[0]!}" + else + assert! digits.length = 2 + s!"{acc}\\x{digits[0]!}{digits[1]!}" + else + acc.push c + +def escapeStringLit (s : String) : String := + s.foldl escapeStringLitAux "\"" ++ "\"" + +end Strata end diff --git a/Strata/DDM/Util/Vector.lean b/Strata/DDM/Util/Vector.lean index 15308f871..5d5cfa3b1 100644 --- a/Strata/DDM/Util/Vector.lean +++ b/Strata/DDM/Util/Vector.lean @@ -4,17 +4,16 @@ SPDX-License-Identifier: Apache-2.0 OR MIT -/ module +set_option autoImplicit false -public section namespace Vector @[inline] -def modify! (v : Vector α n) (i : Nat) (f : α → α) : Vector α n where +def modify! {α n} (v : Vector α n) (i : Nat) (f : α → α) : Vector α n where toArray := v.toArray.modify i f size_toArray := Eq.trans Array.size_modify v.size_toArray @[inline] -def modify (v : Vector α n) (i : Fin n) (f : α → α) : Vector α n := v.modify! i.val f +def modify {α n} (v : Vector α n) (i : Fin n) (f : α → α) : Vector α n := modify! v i.val f end Vector -end diff --git a/Strata/Languages/Boogie/DDMTransform/Parse.lean b/Strata/Languages/Boogie/DDMTransform/Parse.lean index 43f60e314..a223912d8 100644 --- a/Strata/Languages/Boogie/DDMTransform/Parse.lean +++ b/Strata/Languages/Boogie/DDMTransform/Parse.lean @@ -3,7 +3,6 @@ SPDX-License-Identifier: Apache-2.0 OR MIT -/ - import Strata.DDM.Integration.Lean import Strata.DDM.Util.Format import Strata.Languages.Boogie.Boogie diff --git a/Strata/Languages/Python/PythonDialect.lean b/Strata/Languages/Python/PythonDialect.lean index 8ae84a57d..7fcbc4ebd 100644 --- a/Strata/Languages/Python/PythonDialect.lean +++ b/Strata/Languages/Python/PythonDialect.lean @@ -6,11 +6,8 @@ import Strata.DDM.Integration.Lean - namespace Strata - - namespace Python #load_dialect "../../../Tools/Python/test_results/dialects/Python.dialect.st.ion" diff --git a/StrataMain.lean b/StrataMain.lean index 53c5a9298..ced9e9011 100644 --- a/StrataMain.lean +++ b/StrataMain.lean @@ -7,6 +7,7 @@ -- Executable with utilities for working with Strata files. import Strata.DDM.Elab import Strata.DDM.Ion +import Strata.DDM.Util.ByteArray import Strata.Util.IO import Strata.Languages.Python.Python @@ -104,7 +105,7 @@ def readStrataIon (fm : Strata.DialectFileMap) (path : System.FilePath) (bytes : def readFile (fm : Strata.DialectFileMap) (path : System.FilePath) : IO (Strata.Elab.LoadedDialects × Strata.DialectOrProgram) := do let bytes ← Strata.Util.readBinInputSource path.toString let displayPath : System.FilePath := Strata.Util.displayName path.toString - if bytes.startsWith Ion.binaryVersionMarker then + if Ion.isIonFile bytes then readStrataIon fm displayPath bytes else readStrataText fm displayPath bytes @@ -171,7 +172,7 @@ def diffCommand : Command where def readPythonStrata (path : String) : IO Strata.Program := do let bytes ← Strata.Util.readBinInputSource path - if ! bytes.startsWith Ion.binaryVersionMarker then + if ! Ion.isIonFile bytes then exitFailure s!"pyAnalyze expected Ion file" match Strata.Program.fromIon Strata.Python.Python_map Strata.Python.Python.name bytes with | .ok p => pure p diff --git a/StrataTest/DDM/Gen.lean b/StrataTest/DDM/Gen.lean index ed073024f..717f495ab 100644 --- a/StrataTest/DDM/Gen.lean +++ b/StrataTest/DDM/Gen.lean @@ -3,14 +3,19 @@ SPDX-License-Identifier: Apache-2.0 OR MIT -/ +module -import Strata.DDM.Integration.Lean +import Strata.DDM.Integration.Lean.Gen +import Strata.DDM.Integration.Lean.HashCommands + +public import Strata.DDM.AST +public import Strata.DDM.Integration.Lean.OfAstM namespace Strata class IsAST (β : Type → Type) (M : outParam (Type → Type)) where - toAst [Inhabited α] : β α → M α - ofAst [Inhabited α] [Repr α] : M α → OfAstM (β α) + toAst {α} [Inhabited α] : β α → M α + ofAst {α} [Inhabited α] [Repr α] : M α → OfAstM (β α) end Strata @@ -55,8 +60,17 @@ op mkMutACommaSep (a : CommaSepBy MutA) : MutACommaSep => a; namespace TestDialect +set_option trace.Strata.generator true + #strata_gen TestDialect +#print TestDialect.TestDialectType + +#eval ``TestDialectType +public section + +end + /-- info: inductive TestDialect.test : Type → Type number of parameters: 1 @@ -98,6 +112,32 @@ TestDialect.TypeP.type : {α : Type} → α → TypeP α #guard_msgs in #print TypeP +/-- +info: def TestDialect.TypeP.ann : {α : Type} → TypeP α → α := +fun {α} a => TypeP.casesOn a (fun tp => tp.ann) fun ann => ann +-/ +#guard_msgs in +#print TypeP.ann + +/-- +info: inductive TestDialect.Expr : Type → Type +number of parameters: 1 +constructors: +TestDialect.Expr.fvar : {α : Type} → α → Nat → Expr α +TestDialect.Expr.trueExpr : {α : Type} → α → Expr α +TestDialect.Expr.and : {α : Type} → α → Expr α → Expr α → Expr α +TestDialect.Expr.lambda : {α : Type} → α → TestDialectType α → Bindings α → Expr α → Expr α +-/ +#guard_msgs in +#print Expr + +/-- +info: def TestDialect.Expr.ann : {α : Type} → Expr α → α := +fun {α} a => Expr.casesOn a (fun ann idx => ann) (fun ann => ann) (fun ann x y => ann) fun ann tp b res => ann +-/ +#guard_msgs in +#print Expr.ann + /-- info: Strata.ExprF.fvar () 1 -/