From 6450d2c318c6fca001bd9821747aed88f580d185 Mon Sep 17 00:00:00 2001 From: Kyrylo Simonov Date: Wed, 19 Feb 2025 09:33:51 -0600 Subject: [PATCH 1/7] Separate as() and into() --- src/FunSQL.jl | 1 + src/link.jl | 139 +++++++++++++++++++++--------------------- src/nodes.jl | 1 + src/nodes/as.jl | 17 +++--- src/nodes/hide.jl | 42 +++++++++++++ src/nodes/internal.jl | 20 ++---- src/nodes/into.jl | 28 +++++++++ src/resolve.jl | 34 +++++------ src/translate.jl | 66 ++++++++------------ 9 files changed, 196 insertions(+), 152 deletions(-) create mode 100644 src/nodes/hide.jl create mode 100644 src/nodes/into.jl diff --git a/src/FunSQL.jl b/src/FunSQL.jl index fb6a5ca8..618c6610 100644 --- a/src/FunSQL.jl +++ b/src/FunSQL.jl @@ -56,6 +56,7 @@ export funsql_group, funsql_highlight, funsql_in, + funsql_into, funsql_iterate, funsql_is_not_null, funsql_is_null, diff --git a/src/link.jl b/src/link.jl index 00015f41..3736efea 100644 --- a/src/link.jl +++ b/src/link.jl @@ -123,19 +123,15 @@ function dismantle(n::GroupNode, ctx) Group(by = by′, sets = n.sets, name = n.name, label_map = n.label_map, tail = tail′) end -function dismantle(n::IterateNode, ctx) +function dismantle(n::IntoNode, ctx) tail′ = dismantle(ctx) - iterator′ = dismantle(n.iterator, ctx) - Iterate(iterator = iterator′, tail = tail′) + Into(name = n.name, tail = tail′) end -function dismantle(n::JoinNode, ctx) - rt = row_type(n.joinee) - router = JoinRouter(Set(keys(rt.fields)), !isa(rt.group, EmptyType)) +function dismantle(n::IterateNode, ctx) tail′ = dismantle(ctx) - joinee′ = dismantle(n.joinee, ctx) - on′ = dismantle_scalar(n.on, ctx) - RoutedJoin(joinee = joinee′, on = on′, router = router, left = n.left, right = n.right, optional = n.optional, tail = tail′) + iterator′ = dismantle(n.iterator, ctx) + Iterate(iterator = iterator′, tail = tail′) end function dismantle(n::LimitNode, ctx) @@ -181,6 +177,13 @@ function dismantle_scalar(n::ResolvedNode, ctx) end end +function dismantle(n::RoutedJoinNode, ctx) + tail′ = dismantle(ctx) + joinee′ = dismantle(n.joinee, ctx) + on′ = dismantle_scalar(n.on, ctx) + RoutedJoin(joinee = joinee′, on = on′, name = n.name, left = n.left, right = n.right, optional = n.optional, tail = tail′) +end + function dismantle(n::SelectNode, ctx) tail′ = dismantle(ctx) args′ = dismantle_scalar(n.args, ctx) @@ -232,16 +235,7 @@ function link(n::AppendNode, ctx) end function link(n::AsNode, ctx) - refs = SQLQuery[] - for ref in ctx.refs - if @dissect(ref, (local tail) |> Nested(name = (local name))) - @assert name == n.name - push!(refs, tail) - else - error() - end - end - tail′ = link(ctx.tail, ctx, refs) + tail′ = link(ctx) As(name = n.name, tail = tail′) end @@ -289,10 +283,8 @@ function link(n::FromIterateNode, ctx) end function link(n::FromTableExpressionNode, ctx) - refs = ctx.cte_refs[(n.name, n.depth)] - for ref in ctx.refs - push!(refs, Nested(name = n.name, tail = ref)) - end + cte_refs = ctx.cte_refs[(n.name, n.depth)] + append!(cte_refs, ctx.refs) n end @@ -333,6 +325,20 @@ function link(n::GroupNode, ctx) Group(by = n.by, sets = n.sets, name = n.name, label_map = n.label_map, tail = tail′) end +function link(n::IntoNode, ctx) + refs = SQLQuery[] + for ref in ctx.refs + if @dissect(ref, (local tail) |> Nested(name = (local name))) + @assert name == n.name + push!(refs, tail) + else + error() + end + end + tail′ = link(ctx.tail, ctx, refs) + Into(name = n.name, tail = tail′) +end + function link(n::IterateNode, ctx) iterator′ = n.iterator defs = copy(ctx.defs) @@ -364,53 +370,6 @@ function link(n::IterateNode, ctx) Padding(tail = q′) end -function route(r::JoinRouter, ref::SQLQuery) - if @dissect(ref, Nested(name = (local name))) && name in r.label_set - return 1 - end - if @dissect(ref, Get(name = (local name))) && name in r.label_set - return 1 - end - if @dissect(ref, Agg()) && r.group - return 1 - end - return -1 -end - -function link(n::RoutedJoinNode, ctx) - lrefs = SQLQuery[] - rrefs = SQLQuery[] - for ref in ctx.refs - turn = route(n.router, ref) - push!(turn < 0 ? lrefs : rrefs, ref) - end - if n.optional && isempty(rrefs) - return link(ctx) - end - ln_ext_refs = length(lrefs) - rn_ext_refs = length(rrefs) - refs′ = SQLQuery[] - lateral_refs = SQLQuery[] - gather!(n.joinee, ctx, lateral_refs) - append!(lrefs, lateral_refs) - lateral = !isempty(lateral_refs) - gather!(n.on, ctx, refs′) - for ref in refs′ - turn = route(n.router, ref) - push!(turn < 0 ? lrefs : rrefs, ref) - end - tail′ = Linked(lrefs, ln_ext_refs, tail = link(ctx.tail, ctx, lrefs)) - joinee′ = Linked(rrefs, rn_ext_refs, tail = link(n.joinee, ctx, rrefs)) - RoutedJoin( - joinee = joinee′, - on = n.on, - router = n.router, - left = n.left, - right = n.right, - lateral = lateral, - tail = tail′) -end - function link(n::LimitNode, ctx) tail′ = Linked(ctx.refs, tail = link(ctx)) Limit(offset = n.offset, limit = n.limit, tail = tail′) @@ -459,6 +418,46 @@ function link(n::PartitionNode, ctx) Partition(by = n.by, order_by = n.order_by, frame = n.frame, name = n.name, tail = tail′) end +function link(n::RoutedJoinNode, ctx) + lrefs = SQLQuery[] + rrefs = SQLQuery[] + for ref in ctx.refs + if @dissect(ref, Nested(name = (local name))) && name === n.name + push!(rrefs, ref) + else + push!(lrefs, ref) + end + end + if n.optional && isempty(rrefs) + return link(ctx) + end + ln_ext_refs = length(lrefs) + rn_ext_refs = length(rrefs) + refs′ = SQLQuery[] + lateral_refs = SQLQuery[] + gather!(n.joinee, ctx, lateral_refs) + append!(lrefs, lateral_refs) + lateral = !isempty(lateral_refs) + gather!(n.on, ctx, refs′) + for ref in refs′ + if @dissect(ref, Nested(name = (local name))) && name === n.name + push!(rrefs, ref) + else + push!(lrefs, ref) + end + end + tail′ = Linked(lrefs, ln_ext_refs, tail = link(ctx.tail, ctx, lrefs)) + joinee′ = Linked(rrefs, rn_ext_refs, tail = link(Into(name = n.name, tail = n.joinee), ctx, rrefs)) + RoutedJoin( + joinee = joinee′, + on = n.on, + name = n.name, + left = n.left, + right = n.right, + lateral = lateral, + tail = tail′) +end + function link(n::SelectNode, ctx) refs = SQLQuery[] gather!(n.args, ctx, refs) diff --git a/src/nodes.jl b/src/nodes.jl index ab470726..5fc547c0 100644 --- a/src/nodes.jl +++ b/src/nodes.jl @@ -913,6 +913,7 @@ include("nodes/get.jl") include("nodes/group.jl") include("nodes/highlight.jl") include("nodes/internal.jl") +include("nodes/into.jl") include("nodes/iterate.jl") include("nodes/join.jl") include("nodes/limit.jl") diff --git a/src/nodes/as.jl b/src/nodes/as.jl index 3a4a5db7..d1a61a2d 100644 --- a/src/nodes/as.jl +++ b/src/nodes/as.jl @@ -16,8 +16,7 @@ AsNode(name) = As(name; tail = nothing) name => tail -In a scalar context, `As` specifies the name of the output column. When -applied to tabular data, `As` wraps the data in a nested record. +`As` specifies the name of the output column. The arrow operator (`=>`) is a shorthand notation for `As`. @@ -35,19 +34,19 @@ SELECT "person_1"."person_id" AS "id" FROM "person" AS "person_1" ``` -*Show all patients together with their state of residence.* +*Show all patients together with their primary care provider.* ```jldoctest -julia> person = SQLTable(:person, columns = [:person_id, :year_of_birth, :location_id]); +julia> person = SQLTable(:person, columns = [:person_id, :year_of_birth, :provider_id]); -julia> location = SQLTable(:location, columns = [:location_id, :state]); +julia> provider = SQLTable(:provider, columns = [:provider_id, :provider_name]); julia> q = From(:person) |> - Join(From(:location) |> As(:location), - on = Get.location_id .== Get.location.location_id) |> - Select(Get.person_id, Get.location.state); + Join(From(:provider) |> As(:pcp), + on = Get.provider_id .== Get.pcp.provider_id) |> + Select(Get.person_id, Get.pcp.provider_name); -julia> print(render(q, tables = [person, location])) +julia> print(render(q, tables = [person, provider])) SELECT "person_1"."person_id", "location_1"."state" diff --git a/src/nodes/hide.jl b/src/nodes/hide.jl new file mode 100644 index 00000000..36ced02c --- /dev/null +++ b/src/nodes/hide.jl @@ -0,0 +1,42 @@ +# Hide node + +mutable struct HideNode <: TabularNode + over::Union{SQLNode, Nothing} + names::Vector{Symbol} + label_map::FunSQL.OrderedDict{Symbol, Int} + + function HideNode(; over = nothing, names = [], label_map = nothing) + if label_map !== nothing + new(over, names, label_map) + else + n = new(over, names, FunSQL.OrderedDict{Symbol, Int}()) + for (i, name) in enumerate(n.names) + if name in keys(n.label_map) + err = FunSQL.DuplicateLabelError(name, path = [n]) + throw(err) + end + n.label_map[name] = i + end + n + end + end +end + +HideNode(names...; over = nothing) = + HideNode(over = over, names = Symbol[names...]) + +Hide(args...; kws...) = + HideNode(args...; kws...) |> SQLNode + +const funsql_hide = Hide + +dissect(scr::Symbol, ::typeof(Hide), pats::Vector{Any}) = + dissect(scr, HideNode, pats) + +function FunSQL.PrettyPrinting.quoteof(n::HideNode, ctx::FunSQL.QuoteContext) + ex = Expr(:call, nameof(Hide), quoteof(n.names, ctx)...) + if n.over !== nothing + ex = Expr(:call, :|>, FunSQL.quoteof(n.over, ctx), ex) + end + ex +end diff --git a/src/nodes/internal.jl b/src/nodes/internal.jl index 91f866fd..931afb97 100644 --- a/src/nodes/internal.jl +++ b/src/nodes/internal.jl @@ -203,29 +203,21 @@ PrettyPrinting.quoteof(n::FromFunctionNode, ctx::QuoteContext) = # Annotated Join node. -struct JoinRouter - label_set::Set{Symbol} - group::Bool -end - -PrettyPrinting.quoteof(r::JoinRouter) = - Expr(:call, :JoinRouter, quoteof(r.label_set), quoteof(r.group)) - struct RoutedJoinNode <: TabularNode joinee::SQLQuery on::SQLQuery - router::JoinRouter + name::Symbol left::Bool right::Bool lateral::Bool optional::Bool - RoutedJoinNode(; joinee, on, router, left, right, lateral = false, optional = false) = - new(joinee, on, router, left, right, lateral, optional) + RoutedJoinNode(; joinee, on, name = label(joinee), left, right, lateral = false, optional = false) = + new(joinee, on, name, left, right, lateral, optional) end -RoutedJoinNode(joinee, on; router, left = false, right = false, lateral = false, optional = false) = - RoutedJoinNode(joinee = joinee, on = on, router, left = left, right = right, lateral = lateral, optional = optional) +RoutedJoinNode(joinee, on; name = label(joinee), left = false, right = false, lateral = false, optional = false) = + RoutedJoinNode(name = name, on = on, router, left = left, right = right, lateral = lateral, optional = optional) const RoutedJoin = SQLQueryCtor{RoutedJoinNode}(:RoutedJoin) @@ -234,7 +226,7 @@ function PrettyPrinting.quoteof(n::RoutedJoinNode, ctx::QuoteContext) if !ctx.limit push!(ex.args, quoteof(n.joinee, ctx)) push!(ex.args, quoteof(n.on, ctx)) - push!(ex.args, Expr(:kw, :router, quoteof(n.router))) + push!(ex.args, Expr(:kw, :name, QuoteNode(n.name))) if n.left push!(ex.args, Expr(:kw, :left, n.left)) end diff --git a/src/nodes/into.jl b/src/nodes/into.jl new file mode 100644 index 00000000..f421ea1e --- /dev/null +++ b/src/nodes/into.jl @@ -0,0 +1,28 @@ +# Wrap the output into a nested record. + +mutable struct IntoNode <: TabularNode + name::Symbol + + IntoNode(; name::Union{Symbol, AbstractString}) = + new(Symbol(name)) +end + +IntoNode(name) = + IntoNode(; name) + +""" + Into(; name, tail = nothing) + Into(name; tail = nothing) + +`Into` wraps output columns in a nested record. +""" +const Into = SQLQueryCtor{IntoNode}(:Into) + +const funsql_into = Into + +function PrettyPrinting.quoteof(n::IntoNode, ctx::QuoteContext) + Expr(:call, :Into, quoteof(n.name)) +end + +label(n::IntoNode) = + n.name diff --git a/src/resolve.jl b/src/resolve.jl index 73bf67ac..56136edb 100644 --- a/src/resolve.jl +++ b/src/resolve.jl @@ -178,9 +178,8 @@ end function resolve(n::AsNode, ctx) tail′ = resolve(ctx) - t = row_type(tail′) q′ = As(name = n.name, tail = tail′) - Resolved(RowType(FieldTypeMap(n.name => t)), tail = q′) + Resolved(type(tail′), tail = q′) end function resolve_scalar(n::AsNode, ctx) @@ -401,6 +400,13 @@ resolve(::HighlightNode, ctx) = resolve_scalar(::HighlightNode, ctx) = resolve_scalar(ctx) +function resolve(n::IntoNode, ctx) + tail′ = resolve(ctx) + t = row_type(tail′) + q′ = Into(name = n.name, tail = tail′) + Resolved(RowType(FieldTypeMap(n.name => t)), tail = q′) +end + function resolve(n::IterateNode, ctx) tail′ = resolve(ResolveContext(ctx, knot_type = nothing, implicit_knot = false)) t = row_type(tail′) @@ -418,21 +424,18 @@ end function resolve(n::JoinNode, ctx) tail′ = resolve(ctx) lt = row_type(tail′) + name = label(n.joinee) joinee′ = resolve(n.joinee, ResolveContext(ctx, row_type = lt, implicit_knot = false)) rt = row_type(joinee′) fields = FieldTypeMap() for (f, ft) in lt.fields - fields[f] = get(rt.fields, f, ft) + fields[f] = ft end - for (f, ft) in rt.fields - if !haskey(fields, f) - fields[f] = ft - end - end - group = rt.group isa EmptyType ? lt.group : rt.group + fields[name] = rt + group = lt.group t = RowType(fields, group) on′ = resolve_scalar(n.on, ctx, t) - q′ = Join(joinee = joinee′, on = on′, left = n.left, right = n.right, optional = n.optional, tail = tail′) + q′ = RoutedJoin(joinee = joinee′, on = on′, name = name, left = n.left, right = n.right, optional = n.optional, tail = tail′) Resolved(t, tail = q′) end @@ -532,16 +535,7 @@ function resolve(n::Union{WithNode, WithExternalNode}, ctx) v = get(ctx.cte_types, name, nothing) depth = 1 + (v !== nothing ? v[1] : 0) t = row_type(args′[i]) - cte_t = get(t.fields, name, EmptyType()) - if !(cte_t isa RowType) - throw( - ReferenceError( - REFERENCE_ERROR_TYPE.INVALID_TABLE_REFERENCE, - name = name, - path = get_path(ctx))) - - end - cte_types′ = Base.ImmutableDict(cte_types′, name => (depth, cte_t)) + cte_types′ = Base.ImmutableDict(cte_types′, name => (depth, t)) end ctx′ = ResolveContext(ctx, cte_types = cte_types′) tail′ = resolve(ctx′) diff --git a/src/translate.jl b/src/translate.jl index 84b3bd79..0e4a16fb 100644 --- a/src/translate.jl +++ b/src/translate.jl @@ -427,26 +427,8 @@ function assemble(n::AppendNode, ctx) Assemblage(a_name, s, repl = repl, cols = dummy_cols) end -function assemble(n::AsNode, ctx) - refs′ = SQLQuery[] - for ref in ctx.refs - if @dissect(ref, (local tail) |> Nested()) - push!(refs′, tail) - else - push!(refs′, ref) - end - end - base = assemble(TranslateContext(ctx, refs = refs′)) - repl′ = Dict{SQLQuery, Symbol}() - for ref in ctx.refs - if @dissect(ref, (local tail) |> Nested()) - repl′[ref] = base.repl[tail] - else - repl′[ref] = base.repl[ref] - end - end - Assemblage(n.name, base.syntax, cols = base.cols, repl = repl′) -end +assemble(n::AsNode, ctx) = + assemble(ctx) function assemble(n::BindNode, ctx) vars′ = ctx.vars @@ -530,21 +512,12 @@ end assemble(::FromNothingNode, ctx) = assemble(nothing, ctx) -function unwrap_repl(a::Assemblage) - repl′ = Dict{SQLQuery, Symbol}() - for (ref, name) in a.repl - @dissect(ref, (local tail) |> Nested()) || error() - repl′[tail] = name - end - Assemblage(a.name, a.syntax, cols = a.cols, repl = repl′) -end - function assemble(n::FromTableExpressionNode, ctx) cte_a = ctx.ctes[ctx.cte_map[(n.name, n.depth)]] alias = allocate_alias(ctx, n.name) tbl = convert(SQLSyntax, (cte_a.qualifiers, cte_a.name)) s = FROM(AS(name = alias, tail = tbl)) - subs = make_subs(unwrap_repl(cte_a.a), alias) + subs = make_subs(cte_a.a, alias) trns = Pair{SQLQuery, SQLSyntax}[] for ref in ctx.refs push!(trns, ref => subs[ref]) @@ -675,6 +648,27 @@ function assemble(n::GroupNode, ctx) return Assemblage(base.name, s, cols = cols, repl = repl) end +function assemble(n::IntoNode, ctx) + refs′ = SQLQuery[] + for ref in ctx.refs + if @dissect(ref, (local tail) |> Nested()) + push!(refs′, tail) + else + push!(refs′, ref) + end + end + base = assemble(TranslateContext(ctx, refs = refs′)) + repl′ = Dict{SQLQuery, Symbol}() + for ref in ctx.refs + if @dissect(ref, (local tail) |> Nested()) + repl′[ref] = base.repl[tail] + else + repl′[ref] = base.repl[ref] + end + end + Assemblage(n.name, base.syntax, cols = base.cols, repl = repl′) +end + function assemble(n::IterateNode, ctx) ctx′ = TranslateContext(ctx, vars = Base.ImmutableDict{Tuple{Symbol, Int}, SQLSyntax}()) left = assemble(ctx) @@ -883,22 +877,16 @@ function assemble(n::RoutedJoinNode, ctx) right = assemble(n.joinee, ctx) end if @dissect(right.syntax, (local joinee = (ID() || AS())) |> FROM()) && (!n.left || _outer_safe(right)) - for (ref, name) in right.repl - subs[ref] = right.cols[name] - end + right_alias = nothing if ctx.catalog.dialect.has_implicit_lateral lateral = false end else right_alias = allocate_alias(ctx, right) joinee = AS(name = right_alias, tail = complete(right)) - right_cache = Dict{Symbol, SQLSyntax}() - for (ref, name) in right.repl - subs[ref] = get(right_cache, name) do - ID(name = name, tail = right_alias) - end - end end + right_subs = make_subs(right, right_alias) + merge!(subs, right_subs) on = translate(n.on, ctx, subs) s = JOIN(joinee = joinee, on = on, left = n.left, right = n.right, lateral = lateral, tail = tail) trns = Pair{SQLQuery, SQLSyntax}[] From a70eaff4d98242fafd7019d924bbf15362b96a4e Mon Sep 17 00:00:00 2001 From: Kyrylo Simonov Date: Wed, 19 Feb 2025 21:17:08 -0600 Subject: [PATCH 2/7] Make nested records visible --- src/link.jl | 27 +++++++++++++++++---------- 1 file changed, 17 insertions(+), 10 deletions(-) diff --git a/src/link.jl b/src/link.jl index 3736efea..d79cdcbb 100644 --- a/src/link.jl +++ b/src/link.jl @@ -25,16 +25,26 @@ struct LinkContext knot_refs) end -function link(q::SQLQuery) - @dissect(q, (local tail) |> WithContext(catalog = (local catalog))) || throw(IllFormedError()) - ctx = LinkContext(catalog) - t = row_type(tail) +function _select(t::RowType) refs = SQLQuery[] for (f, ft) in t.fields if ft isa ScalarType push!(refs, Get(f)) + else + nested_refs = _select(ft) + for nested_ref in nested_refs + push!(refs, Nested(over = nested_ref, name = f)) + end end end + refs +end + +function link(q::SQLQuery) + @dissect(q, (local tail) |> WithContext(catalog = (local catalog))) || throw(IllFormedError()) + ctx = LinkContext(catalog) + t = row_type(tail) + refs = _select(t) tail′ = Linked(refs, tail = link(dismantle(tail, ctx), ctx, refs)) WithContext(tail = tail′, catalog = catalog, defs = ctx.defs) end @@ -555,12 +565,9 @@ end function gather!(n::IsolatedNode, ctx) def = ctx.defs[n.idx] !@dissect(def, Linked()) || return - refs = SQLQuery[] - for (f, ft) in n.type.fields - if ft isa ScalarType - push!(refs, Get(f)) - break - end + refs = _select(n.type) + if !isempty(refs) + refs = refs[1:1] end def′ = Linked(refs, tail = link(def, ctx, refs)) ctx.defs[n.idx] = def′ From 8149fcfc8f5d41fc88e4baf5f7d6bd258f50e2ad Mon Sep 17 00:00:00 2001 From: Kyrylo Simonov Date: Sat, 22 Feb 2025 13:58:58 -0600 Subject: [PATCH 3/7] Join: add swap option --- src/nodes/join.jl | 25 ++++++++++++++++--------- src/resolve.jl | 4 ++++ 2 files changed, 20 insertions(+), 9 deletions(-) diff --git a/src/nodes/join.jl b/src/nodes/join.jl index c1bc2eb4..372d2272 100644 --- a/src/nodes/join.jl +++ b/src/nodes/join.jl @@ -6,21 +6,22 @@ mutable struct JoinNode <: TabularNode left::Bool right::Bool optional::Bool + swap::Bool - JoinNode(; joinee, on, left = false, right = false, optional = false) = - new(joinee, on, left, right, optional) + JoinNode(; joinee, on, left = false, right = false, optional = false, swap = false) = + new(joinee, on, left, right, optional, swap) end -JoinNode(joinee; on, left = false, right = false, optional = false) = - JoinNode(; joinee, on, left, right, optional) +JoinNode(joinee; on, left = false, right = false, optional = false, swap = false) = + JoinNode(; joinee, on, left, right, optional, swap) -JoinNode(joinee, on; left = false, right = false, optional = false) = - JoinNode(; joinee, on, left, right, optional) +JoinNode(joinee, on; left = false, right = false, optional = false, swap = false) = + JoinNode(; joinee, on, left, right, optional, swap) """ - Join(; joinee, on, left = false, right = false, optional = false) - Join(joinee; on, left = false, right = false, optional = false) - Join(joinee, on; left = false, right = false, optional = false) + Join(; joinee, on, left = false, right = false, optional = false, swap = false) + Join(joinee; on, left = false, right = false, optional = false, swap = false) + Join(joinee, on; left = false, right = false, optional = false, swap = false) `Join` correlates two input datasets. @@ -102,8 +103,14 @@ function PrettyPrinting.quoteof(n::JoinNode, ctx::QuoteContext) if n.optional push!(ex.args, Expr(:kw, :optional, n.optional)) end + if n.swap + push!(ex.args, Expr(:kw, :swap, n.swap)) + end else push!(ex.args, :…) end ex end + +label(n::JoinNode) = + n.swap ? label(n.joinee) : label(n.over) diff --git a/src/resolve.jl b/src/resolve.jl index 56136edb..bcdb6e21 100644 --- a/src/resolve.jl +++ b/src/resolve.jl @@ -422,6 +422,10 @@ function resolve(n::IterateNode, ctx) end function resolve(n::JoinNode, ctx) + if n.swap + ctx′ = ResolveContext(Ctx, tail = n.joinee) + return resolve(JoinNode(joinee = ctx.tail, on = n.on, left = n.right, right = n.left, optional = n.optional), ctx′) + end tail′ = resolve(ctx) lt = row_type(tail′) name = label(n.joinee) From 9edf1591e044a7b2e3ca76074b2207d209ec9a11 Mon Sep 17 00:00:00 2001 From: Kyrylo Simonov Date: Sat, 22 Feb 2025 17:18:45 -0600 Subject: [PATCH 4/7] Add Show/Hide combinators --- src/FunSQL.jl | 2 ++ src/link.jl | 4 +++- src/nodes.jl | 15 +++++++++------ src/nodes/internal.jl | 2 +- src/nodes/show.jl | 38 ++++++++++++++++++++++++++++++++++++++ src/resolve.jl | 29 ++++++++++++++++++++++++++++- src/types.jl | 30 +++++++++++++++++++++++------- 7 files changed, 104 insertions(+), 16 deletions(-) create mode 100644 src/nodes/show.jl diff --git a/src/FunSQL.jl b/src/FunSQL.jl index 618c6610..56e4378a 100644 --- a/src/FunSQL.jl +++ b/src/FunSQL.jl @@ -54,6 +54,7 @@ export funsql_from, funsql_fun, funsql_group, + funsql_hide, funsql_highlight, funsql_in, funsql_into, @@ -83,6 +84,7 @@ export funsql_rank, funsql_row_number, funsql_select, + funsql_show, funsql_sort, funsql_sum, funsql_with diff --git a/src/link.jl b/src/link.jl index d79cdcbb..c9eada38 100644 --- a/src/link.jl +++ b/src/link.jl @@ -27,13 +27,15 @@ end function _select(t::RowType) refs = SQLQuery[] + t.visible || return refs for (f, ft) in t.fields if ft isa ScalarType + ft.visible || continue push!(refs, Get(f)) else nested_refs = _select(ft) for nested_ref in nested_refs - push!(refs, Nested(over = nested_ref, name = f)) + push!(refs, Nested(name = f, tail = nested_ref)) end end end diff --git a/src/nodes.jl b/src/nodes.jl index 5fc547c0..e16e70ef 100644 --- a/src/nodes.jl +++ b/src/nodes.jl @@ -54,17 +54,19 @@ terminal(q::SQLQuery) = Chain(q′, q) = convert(SQLQuery, q)(q′) -label(q::SQLQuery) = - @something label(q.head) label(q.tail) +function label(q::SQLQuery; default = :_) + l = label(q.head) + l !== nothing ? l : label(q.tail; default) +end label(n::AbstractSQLNode) = nothing -label(::Nothing) = - :_ +label(::Nothing; default = :_) = + default -label(q) = - label(convert(SQLQuery, q)) +label(q; default = :_) = + label(convert(SQLQuery, q); default) # A variant of SQLQuery for assembling a chain of identifiers. @@ -922,6 +924,7 @@ include("nodes/order.jl") include("nodes/over.jl") include("nodes/partition.jl") include("nodes/select.jl") +include("nodes/show.jl") include("nodes/sort.jl") include("nodes/variable.jl") include("nodes/where.jl") diff --git a/src/nodes/internal.jl b/src/nodes/internal.jl index 931afb97..36221e3e 100644 --- a/src/nodes/internal.jl +++ b/src/nodes/internal.jl @@ -302,7 +302,7 @@ PrettyPrinting.quoteof(n::FunSQLMacroNode, ctx::QuoteContext) = Expr(:macrocall, Symbol("@funsql"), n.line, !ctx.limit ? n.ex : :…) label(n::FunSQLMacroNode) = - label(n.query) + label(n.query, default = nothing) # Unwrap @funsql macro when displaying the query. diff --git a/src/nodes/show.jl b/src/nodes/show.jl new file mode 100644 index 00000000..f0176ec7 --- /dev/null +++ b/src/nodes/show.jl @@ -0,0 +1,38 @@ +# Show/Hide nodes + +mutable struct ShowNode <: TabularNode + names::Vector{Symbol} + visible::Bool + label_map::FunSQL.OrderedDict{Symbol, Int} + + function ShowNode(; names = [], visible = true, label_map = nothing) + if label_map !== nothing + new(names, visible, label_map) + else + n = new(names, visible, FunSQL.OrderedDict{Symbol, Int}()) + for (i, name) in enumerate(n.names) + if name in keys(n.label_map) + err = FunSQL.DuplicateLabelError(name, path = SQLQuery[n]) + throw(err) + end + n.label_map[name] = i + end + n + end + end +end + +ShowNode(names...; visible = true) = + ShowNode(names = Symbol[names...], visible = visible) + +const Show = SQLQueryCtor{ShowNode}(:Show) + +Hide(args...; kws...) = + Show(args...; kws..., visible = false) + +const funsql_show = Show +const funsql_hide = Hide + +function FunSQL.PrettyPrinting.quoteof(n::ShowNode, ctx::QuoteContext) + Expr(:call, n.visible ? :Show : :Hide, quoteof(n.names, ctx)...) +end diff --git a/src/resolve.jl b/src/resolve.jl index bcdb6e21..1e901673 100644 --- a/src/resolve.jl +++ b/src/resolve.jl @@ -423,7 +423,7 @@ end function resolve(n::JoinNode, ctx) if n.swap - ctx′ = ResolveContext(Ctx, tail = n.joinee) + ctx′ = ResolveContext(ctx, tail = n.joinee) return resolve(JoinNode(joinee = ctx.tail, on = n.on, left = n.right, right = n.left, optional = n.optional), ctx′) end tail′ = resolve(ctx) @@ -506,6 +506,33 @@ function resolve(n::SelectNode, ctx) Resolved(RowType(fields), tail = q′) end +function resolve(n::ShowNode, ctx) + tail′ = resolve(ctx) + t = row_type(tail′) + for name in n.names + ft = get(t.fields, name, EmptyType()) + if ft isa EmptyType + throw( + ReferenceError( + REFERENCE_ERROR_TYPE.UNDEFINED_NAME, + name = name, + path = get_path(ctx))) + end + end + fields = FieldTypeMap() + for (f, ft) in t.fields + if f in keys(n.label_map) + if ft isa ScalarType + ft = ScalarType(visible = n.visible) + else + ft = RowType(ft.fields, ft.group, visible = n.visible) + end + end + fields[f] = ft + end + Resolved(RowType(fields, t.group, visible = t.visible), tail = tail′) +end + function resolve_scalar(n::SortNode, ctx) tail′ = resolve_scalar(ctx) q′ = Sort(value = n.value, nulls = n.nulls, tail = tail′) diff --git a/src/types.jl b/src/types.jl index 856821ed..e04cfec6 100644 --- a/src/types.jl +++ b/src/types.jl @@ -13,17 +13,27 @@ PrettyPrinting.quoteof(::EmptyType) = Expr(:call, nameof(EmptyType)) struct ScalarType <: AbstractSQLType + visible::Bool + + ScalarType(; visible = true) = + new(visible) end -PrettyPrinting.quoteof(::ScalarType) = - Expr(:call, nameof(ScalarType)) +function PrettyPrinting.quoteof(t::ScalarType) + ex = Expr(:call, nameof(ScalarType)) + if !t.visible + push!(ex.args, Expr(:kw, :visible, t.visible)) + end + ex +end struct RowType <: AbstractSQLType fields::OrderedDict{Symbol, Union{ScalarType, RowType}} group::Union{EmptyType, RowType} + visible::Bool - RowType(fields, group = EmptyType()) = - new(fields, group) + RowType(fields, group = EmptyType(); visible = true) = + new(fields, group, visible) end const FieldTypeMap = OrderedDict{Symbol, Union{ScalarType, RowType}} @@ -43,6 +53,9 @@ function PrettyPrinting.quoteof(t::RowType) if !(t.group isa EmptyType) push!(ex.args, Expr(:kw, :group, quoteof(t.group))) end + if !t.visible + push!(ex.args, Expr(:kw, :visible, t.visible)) + end ex end @@ -54,8 +67,8 @@ const EMPTY_ROW = RowType() Base.intersect(::AbstractSQLType, ::AbstractSQLType) = EmptyType() -Base.intersect(::ScalarType, ::ScalarType) = - ScalarType() +Base.intersect(t1::ScalarType, t2::ScalarType) = + ScalarType(visible = t1.visible || t2.visible) function Base.intersect(t1::RowType, t2::RowType) if t1 === t2 @@ -71,7 +84,7 @@ function Base.intersect(t1::RowType, t2::RowType) end end group = intersect(t1.group, t2.group) - RowType(fields, group) + RowType(fields, group, visible = t1.visible || t2.visible) end @@ -98,5 +111,8 @@ function Base.issubset(t1::RowType, t2::RowType) if !issubset(t1.group, t2.group) return false end + if !t1.visible && t2.visible + return false + end return true end From ff18dc057054eb2733b186577950ed1f757e571a Mon Sep 17 00:00:00 2001 From: Kyrylo Simonov Date: Wed, 5 Nov 2025 22:58:24 -0600 Subject: [PATCH 5/7] Replace show()/hide() with "private" parameter for define()/join() --- src/FunSQL.jl | 2 -- src/catalogs.jl | 23 +++++++++------ src/link.jl | 3 +- src/nodes.jl | 1 - src/nodes/define.jl | 20 ++++++++----- src/nodes/hide.jl | 42 --------------------------- src/nodes/into.jl | 19 +++++++----- src/nodes/join.jl | 24 +++++++++------- src/nodes/show.jl | 38 ------------------------ src/resolve.jl | 70 ++++++++++++++++++++++----------------------- src/types.jl | 41 ++++++++++++-------------- 11 files changed, 107 insertions(+), 176 deletions(-) delete mode 100644 src/nodes/hide.jl delete mode 100644 src/nodes/show.jl diff --git a/src/FunSQL.jl b/src/FunSQL.jl index 56e4378a..618c6610 100644 --- a/src/FunSQL.jl +++ b/src/FunSQL.jl @@ -54,7 +54,6 @@ export funsql_from, funsql_fun, funsql_group, - funsql_hide, funsql_highlight, funsql_in, funsql_into, @@ -84,7 +83,6 @@ export funsql_rank, funsql_row_number, funsql_select, - funsql_show, funsql_sort, funsql_sum, funsql_with diff --git a/src/catalogs.jl b/src/catalogs.jl index 0fbd446a..1fe7f2e6 100644 --- a/src/catalogs.jl +++ b/src/catalogs.jl @@ -31,22 +31,24 @@ _metadata_get(dict::SQLMetadata, key::Union{Symbol, AbstractString}, default; st end """ - SQLColumn(; name, metadata = nothing) - SQLColumn(name; metadata = nothing) + SQLColumn(; name, private = false, metadata = nothing) + SQLColumn(name; private = false, metadata = nothing) `SQLColumn` represents a column with the given `name` and optional `metadata`. +If `private` is `true`, the column is excluded from the default query output. """ struct SQLColumn name::Symbol + private::Bool metadata::SQLMetadata - function SQLColumn(; name::Union{Symbol, AbstractString}, metadata = nothing) - new(Symbol(name), _metadata(metadata)) + function SQLColumn(; name::Union{Symbol, AbstractString}, private = false, metadata = nothing) + new(Symbol(name), private, _metadata(metadata)) end end -SQLColumn(name; metadata = nothing) = - SQLColumn(name = name, metadata = metadata) +SQLColumn(name; private = false, metadata = nothing) = + SQLColumn(; name, private, metadata) Base.show(io::IO, col::SQLColumn) = print(io, quoteof(col, limit = true)) @@ -56,6 +58,9 @@ Base.show(io::IO, ::MIME"text/plain", col::SQLColumn) = function PrettyPrinting.quoteof(col::SQLColumn; limit::Bool = false) ex = Expr(:call, nameof(SQLColumn), QuoteNode(col.name)) + if col.private + push(ex.args, Expr(:kw, :private, col.private)) + end if !isempty(col.metadata) push!(ex.args, Expr(:kw, :metadata, limit ? :… : quoteof(reverse!(collect(col.metadata))))) end @@ -122,10 +127,10 @@ struct SQLTable <: AbstractDict{Symbol, SQLColumn} end SQLTable(name; qualifiers = Symbol[], columns, metadata = nothing) = - SQLTable(qualifiers = qualifiers, name = name, columns = columns, metadata = metadata) + SQLTable(; qualifiers, name, columns, metadata) SQLTable(name, columns...; qualifiers = Symbol[], metadata = nothing) = - SQLTable(qualifiers = qualifiers, name = name, columns = [columns...], metadata = metadata) + SQLTable(; qualifiers, name, columns = [columns...], metadata) _column_map(columns::OrderedDict{Symbol, SQLColumn}) = columns @@ -280,7 +285,7 @@ struct SQLCatalog <: AbstractDict{Symbol, SQLTable} end SQLCatalog(tables...; dialect = :default, cache = default_cache_maxsize, metadata = nothing) = - SQLCatalog(tables = tables, dialect = dialect, cache = cache, metadata = metadata) + SQLCatalog(; tables, dialect, cache, metadata) _table_map(tables::Dict{Symbol, SQLTable}) = tables diff --git a/src/link.jl b/src/link.jl index c9eada38..8bb14e13 100644 --- a/src/link.jl +++ b/src/link.jl @@ -27,10 +27,9 @@ end function _select(t::RowType) refs = SQLQuery[] - t.visible || return refs for (f, ft) in t.fields + !(f in t.private_fields) || continue if ft isa ScalarType - ft.visible || continue push!(refs, Get(f)) else nested_refs = _select(ft) diff --git a/src/nodes.jl b/src/nodes.jl index e16e70ef..b8e68a13 100644 --- a/src/nodes.jl +++ b/src/nodes.jl @@ -924,7 +924,6 @@ include("nodes/order.jl") include("nodes/over.jl") include("nodes/partition.jl") include("nodes/select.jl") -include("nodes/show.jl") include("nodes/sort.jl") include("nodes/variable.jl") include("nodes/where.jl") diff --git a/src/nodes/define.jl b/src/nodes/define.jl index fab058de..3fd44dfc 100644 --- a/src/nodes/define.jl +++ b/src/nodes/define.jl @@ -4,13 +4,14 @@ struct DefineNode <: TabularNode args::Vector{SQLQuery} before::Union{Symbol, Bool} after::Union{Symbol, Bool} + private::Bool label_map::OrderedDict{Symbol, Int} - function DefineNode(; args = [], before = nothing, after = nothing, label_map = nothing) + function DefineNode(; args = [], before = nothing, after = nothing, private = false, label_map = nothing) if label_map !== nothing - n = new(args, something(before, false), something(after, false), label_map) + n = new(args, something(before, false), something(after, false), private, label_map) else - n = new(args, something(before, false), something(after, false), OrderedDict{Symbol, Int}()) + n = new(args, something(before, false), something(after, false), private, OrderedDict{Symbol, Int}()) populate_label_map!(n) end if (n.before isa Symbol || n.before) && (n.after isa Symbol || n.after) @@ -20,12 +21,12 @@ struct DefineNode <: TabularNode end end -DefineNode(args...; before = nothing, after = nothing) = - DefineNode(args = SQLQuery[args...], before = before, after = after) +DefineNode(args...; before = nothing, after = nothing, private = false) = + DefineNode(args = SQLQuery[args...], before = before, after = after, private = private) """ - Define(; args = [], before = nothing, after = nothing, tail = nothing) - Define(args...; before = nothing, after = nothing, tail = nothing) + Define(; args = [], before = nothing, after = nothing, private = false, tail = nothing) + Define(args...; before = nothing, after = nothing, private = false, tail = nothing) The `Define` node adds or replaces output columns. @@ -35,6 +36,8 @@ both new and replaced columns at the end (after a specified column). Alternatively, set `before = true` (`before = `) to add both new and replaced columns at the front (before the specified column). +If `private` is set, the columns will be excluded from the query output. + # Examples *Show patients who are at least 16 years old.* @@ -90,5 +93,8 @@ function PrettyPrinting.quoteof(n::DefineNode, ctx::QuoteContext) if n.after !== false push!(ex.args, Expr(:kw, :after, n.after isa Symbol ? QuoteNode(n.after) : n.after)) end + if n.private !== false + push!(ex.args, Expr(:kw, :private, n.private)) + end ex end diff --git a/src/nodes/hide.jl b/src/nodes/hide.jl deleted file mode 100644 index 36ced02c..00000000 --- a/src/nodes/hide.jl +++ /dev/null @@ -1,42 +0,0 @@ -# Hide node - -mutable struct HideNode <: TabularNode - over::Union{SQLNode, Nothing} - names::Vector{Symbol} - label_map::FunSQL.OrderedDict{Symbol, Int} - - function HideNode(; over = nothing, names = [], label_map = nothing) - if label_map !== nothing - new(over, names, label_map) - else - n = new(over, names, FunSQL.OrderedDict{Symbol, Int}()) - for (i, name) in enumerate(n.names) - if name in keys(n.label_map) - err = FunSQL.DuplicateLabelError(name, path = [n]) - throw(err) - end - n.label_map[name] = i - end - n - end - end -end - -HideNode(names...; over = nothing) = - HideNode(over = over, names = Symbol[names...]) - -Hide(args...; kws...) = - HideNode(args...; kws...) |> SQLNode - -const funsql_hide = Hide - -dissect(scr::Symbol, ::typeof(Hide), pats::Vector{Any}) = - dissect(scr, HideNode, pats) - -function FunSQL.PrettyPrinting.quoteof(n::HideNode, ctx::FunSQL.QuoteContext) - ex = Expr(:call, nameof(Hide), quoteof(n.names, ctx)...) - if n.over !== nothing - ex = Expr(:call, :|>, FunSQL.quoteof(n.over, ctx), ex) - end - ex -end diff --git a/src/nodes/into.jl b/src/nodes/into.jl index f421ea1e..655652c0 100644 --- a/src/nodes/into.jl +++ b/src/nodes/into.jl @@ -2,17 +2,18 @@ mutable struct IntoNode <: TabularNode name::Symbol + private::Bool - IntoNode(; name::Union{Symbol, AbstractString}) = - new(Symbol(name)) + IntoNode(; name::Union{Symbol, AbstractString}, private::Bool = false) = + new(Symbol(name), private) end -IntoNode(name) = - IntoNode(; name) +IntoNode(name; private = false) = + IntoNode(; name, private) """ - Into(; name, tail = nothing) - Into(name; tail = nothing) + Into(; name, private = false, tail = nothing) + Into(name; private = false, tail = nothing) `Into` wraps output columns in a nested record. """ @@ -21,7 +22,11 @@ const Into = SQLQueryCtor{IntoNode}(:Into) const funsql_into = Into function PrettyPrinting.quoteof(n::IntoNode, ctx::QuoteContext) - Expr(:call, :Into, quoteof(n.name)) + ex = Expr(:call, :Into, quoteof(n.name)) + if n.private + push!(ex.args, Expr(:kw, :private, n.private)) + end + ex end label(n::IntoNode) = diff --git a/src/nodes/join.jl b/src/nodes/join.jl index 372d2272..027914f3 100644 --- a/src/nodes/join.jl +++ b/src/nodes/join.jl @@ -7,21 +7,22 @@ mutable struct JoinNode <: TabularNode right::Bool optional::Bool swap::Bool + private::Bool - JoinNode(; joinee, on, left = false, right = false, optional = false, swap = false) = - new(joinee, on, left, right, optional, swap) + JoinNode(; joinee, on, left = false, right = false, optional = false, swap = false, private = false) = + new(joinee, on, left, right, optional, swap, private) end -JoinNode(joinee; on, left = false, right = false, optional = false, swap = false) = - JoinNode(; joinee, on, left, right, optional, swap) +JoinNode(joinee; on, left = false, right = false, optional = false, swap = false, private = false) = + JoinNode(; joinee, on, left, right, optional, swap, private) -JoinNode(joinee, on; left = false, right = false, optional = false, swap = false) = - JoinNode(; joinee, on, left, right, optional, swap) +JoinNode(joinee, on; left = false, right = false, optional = false, swap = false, private = false) = + JoinNode(; joinee, on, left, right, optional, swap, private) """ - Join(; joinee, on, left = false, right = false, optional = false, swap = false) - Join(joinee; on, left = false, right = false, optional = false, swap = false) - Join(joinee, on; left = false, right = false, optional = false, swap = false) + Join(; joinee, on, left = false, right = false, optional = false, swap = false, private = false) + Join(joinee; on, left = false, right = false, optional = false, swap = false, private = false) + Join(joinee, on; left = false, right = false, optional = false, swap = false, private = false) `Join` correlates two input datasets. @@ -106,6 +107,9 @@ function PrettyPrinting.quoteof(n::JoinNode, ctx::QuoteContext) if n.swap push!(ex.args, Expr(:kw, :swap, n.swap)) end + if n.private + push!(ex.args, Expr(:kw, :private, n.private)) + end else push!(ex.args, :…) end @@ -113,4 +117,4 @@ function PrettyPrinting.quoteof(n::JoinNode, ctx::QuoteContext) end label(n::JoinNode) = - n.swap ? label(n.joinee) : label(n.over) + n.swap ? label(n.joinee) : nothing diff --git a/src/nodes/show.jl b/src/nodes/show.jl deleted file mode 100644 index f0176ec7..00000000 --- a/src/nodes/show.jl +++ /dev/null @@ -1,38 +0,0 @@ -# Show/Hide nodes - -mutable struct ShowNode <: TabularNode - names::Vector{Symbol} - visible::Bool - label_map::FunSQL.OrderedDict{Symbol, Int} - - function ShowNode(; names = [], visible = true, label_map = nothing) - if label_map !== nothing - new(names, visible, label_map) - else - n = new(names, visible, FunSQL.OrderedDict{Symbol, Int}()) - for (i, name) in enumerate(n.names) - if name in keys(n.label_map) - err = FunSQL.DuplicateLabelError(name, path = SQLQuery[n]) - throw(err) - end - n.label_map[name] = i - end - n - end - end -end - -ShowNode(names...; visible = true) = - ShowNode(names = Symbol[names...], visible = visible) - -const Show = SQLQueryCtor{ShowNode}(:Show) - -Hide(args...; kws...) = - Show(args...; kws..., visible = false) - -const funsql_show = Show -const funsql_hide = Hide - -function FunSQL.PrettyPrinting.quoteof(n::ShowNode, ctx::QuoteContext) - Expr(:call, n.visible ? :Show : :Hide, quoteof(n.names, ctx)...) -end diff --git a/src/resolve.jl b/src/resolve.jl index 1e901673..2628e7f9 100644 --- a/src/resolve.jl +++ b/src/resolve.jl @@ -263,16 +263,28 @@ function resolve(n::DefineNode, ctx) end end end + private_fields = copy(t.private_fields) + for l in keys(n.label_map) + if n.private + push!(private_fields, l) + else + delete!(private_fields, l) + end + end q′ = Define(args = args′, label_map = n.label_map, tail = tail′) - Resolved(RowType(fields, t.group), tail = q′) + Resolved(RowType(fields, t.group, private_fields), tail = q′) end function RowType(table::SQLTable) fields = FieldTypeMap() - for f in keys(table.columns) + private_fields = Set{Symbol}() + for (f, c) in table.columns fields[f] = ScalarType() + if c.private + push!(private_fields, f) + end end - RowType(fields) + RowType(fields, EmptyType(), private_fields) end function resolve(n::FromNode, ctx) @@ -390,8 +402,12 @@ function resolve(n::GroupNode, ctx) fields[n.name] = RowType(FieldTypeMap(), group) group = EmptyType() end + private_fields = Set{Symbol}() + if n.name !== nothing + push!(private_fields, n.name) + end q′ = Group(by = by′, sets = n.sets, label_map = n.label_map, tail = tail′) - Resolved(RowType(fields, group), tail = q′) + Resolved(RowType(fields, group, private_fields), tail = q′) end resolve(::HighlightNode, ctx) = @@ -404,7 +420,8 @@ function resolve(n::IntoNode, ctx) tail′ = resolve(ctx) t = row_type(tail′) q′ = Into(name = n.name, tail = tail′) - Resolved(RowType(FieldTypeMap(n.name => t)), tail = q′) + t′ = RowType(FieldTypeMap(n.name => t), EmptyType(), n.private ? Set([n.name]) : Set{Symbol}()) + Resolved(t′, tail = q′) end function resolve(n::IterateNode, ctx) @@ -424,7 +441,7 @@ end function resolve(n::JoinNode, ctx) if n.swap ctx′ = ResolveContext(ctx, tail = n.joinee) - return resolve(JoinNode(joinee = ctx.tail, on = n.on, left = n.right, right = n.left, optional = n.optional), ctx′) + return resolve(JoinNode(joinee = ctx.tail, on = n.on, left = n.right, right = n.left, optional = n.optional, private = n.private), ctx′) end tail′ = resolve(ctx) lt = row_type(tail′) @@ -437,7 +454,13 @@ function resolve(n::JoinNode, ctx) end fields[name] = rt group = lt.group - t = RowType(fields, group) + private_fields = copy(lt.private_fields) + if n.private + push!(private_fields, name) + else + delete!(private_fields, name) + end + t = RowType(fields, group, private_fields) on′ = resolve_scalar(n.on, ctx, t) q′ = RoutedJoin(joinee = joinee′, on = on′, name = name, left = n.left, right = n.right, optional = n.optional, tail = tail′) Resolved(t, tail = q′) @@ -490,8 +513,12 @@ function resolve(n::PartitionNode, ctx) end fields[n.name] = RowType(FieldTypeMap(), t) end + private_fields = copy(t.private_fields) + if n.name !== nothing + push!(private_fields, n.name) + end q′ = Partition(by = by′, order_by = order_by′, frame = n.frame, name = n.name, tail = tail′) - Resolved(RowType(fields, group), tail = q′) + Resolved(RowType(fields, group, private_fields), tail = q′) end function resolve(n::SelectNode, ctx) @@ -506,33 +533,6 @@ function resolve(n::SelectNode, ctx) Resolved(RowType(fields), tail = q′) end -function resolve(n::ShowNode, ctx) - tail′ = resolve(ctx) - t = row_type(tail′) - for name in n.names - ft = get(t.fields, name, EmptyType()) - if ft isa EmptyType - throw( - ReferenceError( - REFERENCE_ERROR_TYPE.UNDEFINED_NAME, - name = name, - path = get_path(ctx))) - end - end - fields = FieldTypeMap() - for (f, ft) in t.fields - if f in keys(n.label_map) - if ft isa ScalarType - ft = ScalarType(visible = n.visible) - else - ft = RowType(ft.fields, ft.group, visible = n.visible) - end - end - fields[f] = ft - end - Resolved(RowType(fields, t.group, visible = t.visible), tail = tail′) -end - function resolve_scalar(n::SortNode, ctx) tail′ = resolve_scalar(ctx) q′ = Sort(value = n.value, nulls = n.nulls, tail = tail′) diff --git a/src/types.jl b/src/types.jl index e04cfec6..24d0c369 100644 --- a/src/types.jl +++ b/src/types.jl @@ -13,27 +13,21 @@ PrettyPrinting.quoteof(::EmptyType) = Expr(:call, nameof(EmptyType)) struct ScalarType <: AbstractSQLType - visible::Bool - - ScalarType(; visible = true) = - new(visible) + ScalarType() = + new() end function PrettyPrinting.quoteof(t::ScalarType) - ex = Expr(:call, nameof(ScalarType)) - if !t.visible - push!(ex.args, Expr(:kw, :visible, t.visible)) - end - ex + Expr(:call, nameof(ScalarType)) end struct RowType <: AbstractSQLType fields::OrderedDict{Symbol, Union{ScalarType, RowType}} group::Union{EmptyType, RowType} - visible::Bool + private_fields::Set{Symbol} - RowType(fields, group = EmptyType(); visible = true) = - new(fields, group, visible) + RowType(fields, group = EmptyType(), private_fields = Set{Symbol}()) = + new(fields, group, private_fields) end const FieldTypeMap = OrderedDict{Symbol, Union{ScalarType, RowType}} @@ -42,8 +36,8 @@ const GroupType = Union{EmptyType, RowType} RowType() = RowType(FieldTypeMap()) -RowType(fields::Pair{Symbol, <:AbstractSQLType}...; group = EmptyType()) = - RowType(FieldTypeMap(fields), group) +RowType(fields::Pair{Symbol, <:AbstractSQLType}...; group = EmptyType(), private_fields = Set{Symbol}()) = + RowType(FieldTypeMap(fields), group, private_fields) function PrettyPrinting.quoteof(t::RowType) ex = Expr(:call, nameof(RowType)) @@ -53,8 +47,8 @@ function PrettyPrinting.quoteof(t::RowType) if !(t.group isa EmptyType) push!(ex.args, Expr(:kw, :group, quoteof(t.group))) end - if !t.visible - push!(ex.args, Expr(:kw, :visible, t.visible)) + if !isempty(t.private_fields) + push!(ex.args, Expr(:kw, :private_fields, t.private_fields)) end ex end @@ -67,24 +61,28 @@ const EMPTY_ROW = RowType() Base.intersect(::AbstractSQLType, ::AbstractSQLType) = EmptyType() -Base.intersect(t1::ScalarType, t2::ScalarType) = - ScalarType(visible = t1.visible || t2.visible) +Base.intersect(::ScalarType, ::ScalarType) = + ScalarType() function Base.intersect(t1::RowType, t2::RowType) if t1 === t2 return t1 end fields = FieldTypeMap() + private_fields = Set{Symbol}() for f in keys(t1.fields) if f in keys(t2.fields) t = intersect(t1.fields[f], t2.fields[f]) if !isa(t, EmptyType) fields[f] = t + if f in t1.private_fields && f in t2.private_fields + push!(private_fields, f) + end end end end group = intersect(t1.group, t2.group) - RowType(fields, group, visible = t1.visible || t2.visible) + RowType(fields, group, private_fields) end @@ -104,15 +102,12 @@ function Base.issubset(t1::RowType, t2::RowType) return true end for f in keys(t1.fields) - if !(f in keys(t2.fields) && issubset(t1.fields[f], t2.fields[f])) + if !(f in keys(t2.fields) && issubset(t1.fields[f], t2.fields[f]) && (!(f in t1.private_fields) || f in t2.private_fields)) return false end end if !issubset(t1.group, t2.group) return false end - if !t1.visible && t2.visible - return false - end return true end From 547dc258d4beb1236ed2132507aa7b930a73d9db Mon Sep 17 00:00:00 2001 From: Kyrylo Simonov Date: Fri, 28 Nov 2025 13:12:44 -0600 Subject: [PATCH 6/7] Prefix output columns from a JOIN --- Project.toml | 4 +++- src/FunSQL.jl | 1 + src/clauses/internal.jl | 10 +++++----- src/link.jl | 33 ++++++++++++++++++++++++--------- src/nodes/internal.jl | 8 ++++++-- src/serialize.jl | 4 ++-- src/strings.jl | 22 ++++++++++++---------- src/translate.jl | 23 +++++++++++------------ 8 files changed, 64 insertions(+), 41 deletions(-) diff --git a/Project.toml b/Project.toml index 2d76c69e..bea4bf08 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "FunSQL" uuid = "cf6cc811-59f4-4a10-b258-a8547a8f6407" -authors = ["Kirill Simonov ", "Clark C. Evans "] version = "0.15.0" +authors = ["Kirill Simonov ", "Clark C. Evans "] [deps] DBInterface = "a10d1c49-ce27-4219-8d33-6db1a4562965" @@ -11,6 +11,7 @@ LRUCache = "8ac3fa9e-de4c-5943-b1dc-09c6b5f20637" OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" PrettyPrinting = "54e16d92-306c-5ea0-a30b-337be88ac337" Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" +URIs = "5c2747f8-b7ea-4ff2-ba2e-563bfd36b1d4" [compat] DBInterface = "2.5" @@ -19,4 +20,5 @@ LRUCache = "1.3" OrderedCollections = "1.4" PrettyPrinting = "0.3.2, 0.4" Tables = "1.6" +URIs = "1.6.1" julia = "1.10" diff --git a/src/FunSQL.jl b/src/FunSQL.jl index 618c6610..1f99cf5a 100644 --- a/src/FunSQL.jl +++ b/src/FunSQL.jl @@ -95,6 +95,7 @@ using Tables using DBInterface using LRUCache using DataAPI +using URIs const SQLLiteralType = Union{Missing, Bool, Number, AbstractString, Dates.AbstractTime} diff --git a/src/clauses/internal.jl b/src/clauses/internal.jl index 8eea0184..6f069dbe 100644 --- a/src/clauses/internal.jl +++ b/src/clauses/internal.jl @@ -4,10 +4,10 @@ struct WithContextClause <: AbstractSQLClause dialect::SQLDialect - columns::Union{Vector{SQLColumn}, Nothing} + table::Union{SQLTable, Nothing} - WithContextClause(; dialect, columns = nothing) = - new(dialect, columns) + WithContextClause(; dialect, table = nothing) = + new(dialect, table) end const WITH_CONTEXT = SQLSyntaxCtor{WithContextClause}(:WITH_CONTEXT) @@ -17,8 +17,8 @@ function PrettyPrinting.quoteof(c::WithContextClause, ctx::QuoteContext) if c.dialect !== default_dialect push!(ex.args, Expr(:kw, :dialect, quoteof(c.dialect))) end - if c.columns !== nothing - push!(ex.args, Expr(:kw, :columns, Expr(:vect, Any[quoteof(col) for col in c.columns]...))) + if c.table !== nothing + push!(ex.args, Expr(:kw, :table, quoteof(c.table))) end ex end diff --git a/src/link.jl b/src/link.jl index 8bb14e13..21cb9a0b 100644 --- a/src/link.jl +++ b/src/link.jl @@ -26,28 +26,43 @@ struct LinkContext end function _select(t::RowType) - refs = SQLQuery[] + args = Pair{SQLColumn, SQLQuery}[] for (f, ft) in t.fields !(f in t.private_fields) || continue + lbl = escapefield(f) if ft isa ScalarType - push!(refs, Get(f)) + push!(args, SQLColumn(lbl) => Get(f)) else - nested_refs = _select(ft) - for nested_ref in nested_refs - push!(refs, Nested(name = f, tail = nested_ref)) + nested_args = _select(ft) + for (nested_col, nested_ref) in nested_args + nested_lbl = Symbol(lbl, '.', nested_col.name) + nested_col = SQLColumn(nested_lbl) + push!(args, nested_col => Nested(name = f, tail = nested_ref)) end end end - refs + args end +escapefield(f::Symbol) = + Symbol(escapefield(string(f))) + +escapefield(f) = + escapeuri(f, c -> c == ' ' || c != '.' && URIs.issafe(c)) + function link(q::SQLQuery) @dissect(q, (local tail) |> WithContext(catalog = (local catalog))) || throw(IllFormedError()) ctx = LinkContext(catalog) t = row_type(tail) - refs = _select(t) + col_refs = _select(t) + columns = first.(col_refs) + if isempty(columns) + columns = [SQLColumn(:_)] + end + table = SQLTable(escapefield(label(q)), columns = columns) + refs = last.(col_refs) tail′ = Linked(refs, tail = link(dismantle(tail, ctx), ctx, refs)) - WithContext(tail = tail′, catalog = catalog, defs = ctx.defs) + WithContext(tail = tail′, catalog = catalog, table = table, defs = ctx.defs) end function dismantle(q::SQLQuery, ctx) @@ -566,7 +581,7 @@ end function gather!(n::IsolatedNode, ctx) def = ctx.defs[n.idx] !@dissect(def, Linked()) || return - refs = _select(n.type) + refs = last.(_select(n.type)) if !isempty(refs) refs = refs[1:1] end diff --git a/src/nodes/internal.jl b/src/nodes/internal.jl index 36221e3e..5ec0bd44 100644 --- a/src/nodes/internal.jl +++ b/src/nodes/internal.jl @@ -4,9 +4,10 @@ struct WithContextNode <: AbstractSQLNode catalog::SQLCatalog defs::Vector{SQLQuery} + table::Union{SQLTable, Nothing} - WithContextNode(; catalog = SQLCatalog(), defs = SQLQuery[]) = - new(catalog, defs) + WithContextNode(; catalog = SQLCatalog(), defs = SQLQuery[], table = nothing) = + new(catalog, defs, table) end const WithContext = SQLQueryCtor{WithContextNode}(:WithContext) @@ -17,6 +18,9 @@ function PrettyPrinting.quoteof(n::WithContextNode, ctx::QuoteContext) if !isempty(n.defs) push!(ex.args, Expr(:kw, :defs, Expr(:vect, Any[quoteof(def, ctx) for def in n.defs]...))) end + if n.table !== nothing + push!(ex.args, Expr(:kw, :table, quoteof(n.table))) + end ex end diff --git a/src/serialize.jl b/src/serialize.jl index 16d8a8b5..02deb302 100644 --- a/src/serialize.jl +++ b/src/serialize.jl @@ -13,11 +13,11 @@ mutable struct SerializeContext <: IO end function serialize(s::SQLSyntax) - @dissect(s, WITH_CONTEXT(tail = (local s′), dialect = (local dialect), columns = (local columns))) || throw(IllFormedError()) + @dissect(s, WITH_CONTEXT(tail = (local s′), dialect = (local dialect), table = (local table))) || throw(IllFormedError()) ctx = SerializeContext(dialect, s′) serialize!(ctx) raw = String(take!(ctx.io)) - SQLString(raw, columns = columns, vars = ctx.vars) + SQLString(raw, table = table, vars = ctx.vars) end Base.write(ctx::SerializeContext, octet::UInt8) = diff --git a/src/strings.jl b/src/strings.jl index 122a8848..c0d2cce6 100644 --- a/src/strings.jl +++ b/src/strings.jl @@ -5,7 +5,8 @@ Serialized SQL query. -Parameter `columns` is a vector describing the output columns. +Parameter `table` is an optional `SQLTable` object describing the output of +the query. Parameter `vars` is a vector of query parameters (created with [`Var`](@ref)) in the order they are expected by the `DBInterface.execute()` function. @@ -55,11 +56,11 @@ SQLString(\""" """ struct SQLString <: AbstractString raw::String - columns::Union{Vector{SQLColumn}, Nothing} + table::Union{SQLTable, Nothing} vars::Vector{Symbol} - SQLString(raw; columns = nothing, vars = Symbol[]) = - new(raw, columns, vars) + SQLString(raw; table = nothing, vars = Symbol[]) = + new(raw, table, vars) end Base.ncodeunits(sql::SQLString) = @@ -88,8 +89,8 @@ Base.write(io::IO, sql::SQLString) = function PrettyPrinting.quoteof(sql::SQLString) ex = Expr(:call, nameof(SQLString), sql.raw) - if sql.columns !== nothing - push!(ex.args, Expr(:kw, :columns, Expr(:vect, Any[quoteof(col) for col in sql.columns]...))) + if sql.table !== nothing + push!(ex.args, Expr(:kw, :table, quoteof(sql.table))) end if !isempty(sql.vars) push!(ex.args, Expr(:kw, :vars, quoteof(sql.vars))) @@ -100,10 +101,11 @@ end function Base.show(io::IO, sql::SQLString) print(io, "SQLString(") show(io, sql.raw) - if sql.columns !== nothing - print(io, ", columns = ") - l = length(sql.columns) - print(io, l == 0 ? "[]" : l == 1 ? "[…1 column…]" : "[…$l columns…]") + if sql.table !== nothing + print(io, ", table = SQLTable(") + show(io, sql.table.name) + l = length(sql.table.columns) + print(io, ", ", l == 0 ? "[]" : l == 1 ? "[…1 column…]" : "[…$l columns…]", ")") end if !isempty(sql.vars) print(io, ", vars = ") diff --git a/src/translate.jl b/src/translate.jl index 0e4a16fb..2b124019 100644 --- a/src/translate.jl +++ b/src/translate.jl @@ -39,10 +39,13 @@ function complete(a::Assemblage) end # Add a SELECT clause aligned with the exported references. -function complete_aligned(a::Assemblage, ctx) +function complete_aligned(a::Assemblage, ctx, expected_names = nothing) + names = collect(keys(a.cols)) + expected_names = something(expected_names, names) aligned = length(a.cols) == length(ctx.refs) && - all(a.repl[ref] === name for (name, ref) in zip(keys(a.cols), ctx.refs)) + all(a.repl[ref] === name for (name, ref) in zip(names, ctx.refs)) && + names == expected_names !aligned || return complete(a) if !@dissect(a.syntax, SELECT() || UNION()) alias = nothing @@ -54,9 +57,9 @@ function complete_aligned(a::Assemblage, ctx) subs = make_subs(a, alias) repl = Dict{SQLQuery, Symbol}() cols = OrderedDict{Symbol, SQLSyntax}() - for ref in ctx.refs - name = repl[ref] = a.repl[ref] - cols[name] = subs[ref] + for (expected_name, ref) in zip(expected_names, ctx.refs) + cols[expected_name] = subs[ref] + repl[ref] = expected_name end a′ = Assemblage(a.name, syntax, repl = repl, cols = cols) complete(a′) @@ -211,15 +214,11 @@ function allocate_alias(ctx::TranslateContext, alias::Symbol) end function translate(q::SQLQuery) - @dissect(q, (local q′) |> Linked(refs = (local refs)) |> WithContext(catalog = (local catalog), defs = (local defs))) || throw(IllFormedError()) + @dissect(q, (local q′) |> Linked(refs = (local refs)) |> WithContext(catalog = (local catalog), table = (local table), defs = (local defs))) || throw(IllFormedError()) ctx = TranslateContext(catalog = catalog, defs = defs) ctx′ = TranslateContext(ctx, refs = refs) base = assemble(q′, ctx′) - columns = nothing - if !isempty(refs) - columns = [SQLColumn(base.repl[ref]) for ref in refs] - end - c = complete_aligned(base, ctx′) + c = complete_aligned(base, ctx′, collect(keys(table))) with_args = SQLSyntax[] for cte_a in ctx.ctes !cte_a.external || continue @@ -238,7 +237,7 @@ function translate(q::SQLQuery) if !isempty(with_args) c = WITH(tail = c, args = with_args, recursive = ctx.recursive[]) end - WITH_CONTEXT(tail = c, dialect = ctx.catalog.dialect, columns = columns) + WITH_CONTEXT(tail = c, dialect = ctx.catalog.dialect, table = table) end function translate(q::SQLQuery, ctx) From b177c273ed179cf905902e477bc0854c7f22e5bf Mon Sep 17 00:00:00 2001 From: Kyrylo Simonov Date: Sat, 29 Nov 2025 20:23:22 -0600 Subject: [PATCH 7/7] Allow qualifiers for literals and scalar functions --- src/link.jl | 66 ++++++++++++++++++++++++++--------------- src/nodes/function.jl | 3 -- src/nodes/literal.jl | 3 -- src/resolve.jl | 14 +++++++-- src/translate.jl | 68 +++++++++++++++++++++++++------------------ 5 files changed, 93 insertions(+), 61 deletions(-) diff --git a/src/link.jl b/src/link.jl index 21cb9a0b..85aed317 100644 --- a/src/link.jl +++ b/src/link.jl @@ -315,9 +315,22 @@ function link(n::FromTableExpressionNode, ctx) end function link(n::GroupNode, ctx) - has_aggregates = any(ref -> @dissect(ref, Agg() || Agg() |> Nested()), ctx.refs) - if !has_aggregates && isempty(n.by) - return link(FromNothing(), ctx) + krefs = SQLQuery[] + arefs = SQLQuery[] + has_aggregates = false + for ref in ctx.refs + if @dissect(ref, (local tail) |> Nested(name = (local name))) && name === n.name + gather!(tail, ctx, arefs) + has_aggregates = true + elseif @dissect(ref, Agg()) && n.name === nothing + push!(arefs, ref) + has_aggregates = true + else + push!(krefs, ref) + end + end + if isempty(arefs) && isempty(n.by) + return link(n.name !== nothing ? Into(n.name, tail = FromNothing()) : FromNothing(), ctx) end # Some group keys are added both to SELECT and to GROUP BY. # To avoid duplicate SQL, they must be evaluated in a nested subquery. @@ -330,16 +343,16 @@ function link(n::GroupNode, ctx) # Ignore `SELECT DISTINCT` case. if has_aggregates ctx′ = LinkContext(ctx, refs = refs) - for ref in ctx.refs - if (@dissect(ref, nothing |> Agg(args = (local args), filter = (local filter)) |> Nested(name = (local name))) && name === n.name) || - (@dissect(ref, nothing |> Agg(args = (local args), filter = (local filter))) && n.name === nothing) - gather!(args, ctx′) - if filter !== nothing - gather!(filter, ctx′) - end - elseif @dissect(ref, nothing |> Get(name = (local name))) && name in keys(n.label_map) - # Force evaluation in a nested subquery. - push!(refs, n.by[n.label_map[name]]) + for ref in krefs + @dissect(ref, Get(name = (local name))) && name in keys(n.label_map) || error() + # Force evaluation in a nested subquery. + push!(refs, n.by[n.label_map[name]]) + end + for ref in arefs + @dissect(ref, Agg(args = (local args), filter = (local filter))) || error() + gather!(args, ctx′) + if filter !== nothing + gather!(filter, ctx′) end end end @@ -356,12 +369,12 @@ function link(n::IntoNode, ctx) for ref in ctx.refs if @dissect(ref, (local tail) |> Nested(name = (local name))) @assert name == n.name - push!(refs, tail) + gather!(tail, ctx, refs) else error() end end - tail′ = link(ctx.tail, ctx, refs) + tail′ = Linked(refs, 0, tail = link(ctx.tail, ctx, refs)) Into(name = n.name, tail = tail′) end @@ -418,16 +431,14 @@ end function link(n::PartitionNode, ctx) refs = SQLQuery[] - imm_refs = SQLQuery[] - ctx′ = LinkContext(ctx, refs = imm_refs) + arefs = SQLQuery[] has_aggregates = false for ref in ctx.refs - if (@dissect(ref, nothing |> Agg(args = (local args), filter = (local filter)) |> Nested(name = (local name))) && name === n.name) || - (@dissect(ref, nothing |> Agg(args = (local args), filter = (local filter))) && n.name === nothing) - gather!(args, ctx′) - if filter !== nothing - gather!(filter, ctx′) - end + if @dissect(ref, (local tail) |> Nested(name = (local name))) && name === n.name + gather!(tail, ctx, arefs) + has_aggregates = true + elseif @dissect(ref, Agg()) && n.name === nothing + push!(arefs, ref) has_aggregates = true else push!(refs, ref) @@ -436,6 +447,15 @@ function link(n::PartitionNode, ctx) if !has_aggregates return link(ctx) end + imm_refs = SQLQuery[] + ctx′ = LinkContext(ctx, refs = imm_refs) + for ref in arefs + @dissect(ref, Agg(args = (local args), filter = (local filter))) || error() + gather!(args, ctx′) + if filter !== nothing + gather!(filter, ctx′) + end + end gather!(n.by, ctx′) gather!(n.order_by, ctx′) n_ext_refs = length(refs) diff --git a/src/nodes/function.jl b/src/nodes/function.jl index aae32314..cc60dbf4 100644 --- a/src/nodes/function.jl +++ b/src/nodes/function.jl @@ -160,9 +160,6 @@ const funsql_fun = Fun transliterate(::typeof(Fun), name::Symbol, ctx::TransliterateContext, @nospecialize(args...)) = Fun(name, args = [transliterate(SQLQuery, arg, ctx) for arg in args]) -terminal(::Type{FunctionNode}) = - true - PrettyPrinting.quoteof(n::FunctionNode, ctx::QuoteContext) = Expr(:call, Expr(:., :Fun, diff --git a/src/nodes/literal.jl b/src/nodes/literal.jl index f4ff534e..593be93d 100644 --- a/src/nodes/literal.jl +++ b/src/nodes/literal.jl @@ -45,8 +45,5 @@ Base.convert(::Type{SQLQuery}, val::SQLLiteralType) = Base.convert(::Type{SQLQuery}, ref::Base.RefValue) = Lit(ref.x) -terminal(::Type{LiteralNode}) = - true - PrettyPrinting.quoteof(n::LiteralNode, ctx::QuoteContext) = Expr(:call, :Lit, n.val) diff --git a/src/resolve.jl b/src/resolve.jl index 2628e7f9..5608f7c1 100644 --- a/src/resolve.jl +++ b/src/resolve.jl @@ -342,6 +342,10 @@ function resolve(n::FromNode, ctx) end function resolve_scalar(n::FunctionNode, ctx) + if ctx.tail !== nothing + q′ = unnest(ctx.tail, convert(SQLQuery, n), ctx) + return resolve_scalar(q′, ctx) + end args′ = resolve_scalar(n.args, ctx) q′ = Fun(name = n.name, args = args′) Resolved(ScalarType(), tail = q′) @@ -367,7 +371,7 @@ resolve_scalar(n::FunSQLMacroNode, ctx) = function resolve(n::GetNode, ctx) if ctx.tail !== nothing - q′ = unnest(ctx.tail, Get(n.name), ctx) + q′ = unnest(ctx.tail, convert(SQLQuery, n), ctx) return resolve(q′, ctx) end resolve(FromNode(n.name), ctx) @@ -375,7 +379,7 @@ end function resolve_scalar(n::GetNode, ctx) if ctx.tail !== nothing - q′ = unnest(ctx.tail, Get(n.name), ctx) + q′ = unnest(ctx.tail, convert(SQLQuery, n), ctx) return resolve_scalar(q′, ctx) end t = get(ctx.row_type.fields, n.name, EmptyType()) @@ -406,7 +410,7 @@ function resolve(n::GroupNode, ctx) if n.name !== nothing push!(private_fields, n.name) end - q′ = Group(by = by′, sets = n.sets, label_map = n.label_map, tail = tail′) + q′ = Group(by = by′, sets = n.sets, name = n.name, label_map = n.label_map, tail = tail′) Resolved(RowType(fields, group, private_fields), tail = q′) end @@ -477,6 +481,10 @@ function resolve(n::LimitNode, ctx) end function resolve_scalar(n::LiteralNode, ctx) + if ctx.tail !== nothing + q′ = unnest(ctx.tail, convert(SQLQuery, n), ctx) + return resolve_scalar(q′, ctx) + end Resolved(ScalarType(), tail = convert(SQLQuery, n)) end diff --git a/src/translate.jl b/src/translate.jl index 2b124019..a616b1ef 100644 --- a/src/translate.jl +++ b/src/translate.jl @@ -175,6 +175,7 @@ struct TranslateContext refs::Vector{SQLQuery} vars::Base.ImmutableDict{Tuple{Symbol, Int}, SQLSyntax} subs::Dict{SQLQuery, SQLSyntax} + partition::Union{SQLSyntax, Nothing} TranslateContext(; catalog, defs) = new(catalog, @@ -187,9 +188,10 @@ struct TranslateContext 0, SQLQuery[], Base.ImmutableDict{Tuple{Symbol, Int}, SQLSyntax}(), - Dict{Int, SQLSyntax}()) + Dict{Int, SQLSyntax}(), + nothing) - function TranslateContext(ctx::TranslateContext; tail = ctx.tail, cte_map = ctx.cte_map, knot = ctx.knot, refs = ctx.refs, vars = ctx.vars, subs = ctx.subs) + function TranslateContext(ctx::TranslateContext; tail = ctx.tail, cte_map = ctx.cte_map, knot = ctx.knot, refs = ctx.refs, vars = ctx.vars, subs = ctx.subs, partition = ctx.partition) new(ctx.catalog, tail, ctx.defs, @@ -200,7 +202,8 @@ struct TranslateContext knot, refs, vars, - subs) + subs, + partition) end end @@ -265,9 +268,10 @@ function translate(q, ctx::TranslateContext, subs::Dict{SQLQuery, SQLSyntax}) end function translate(n::AggregateNode, ctx) - args = translate(n.args, ctx) - filter = translate(n.filter, ctx) - AGG(n.name, args = args, filter = filter) + ctx′ = ctx.partition !== nothing ? TranslateContext(ctx, partition = nothing) : ctx + args = translate(n.args, ctx′) + filter = translate(n.filter, ctx′) + AGG(n.name, args = args, filter = filter, over = ctx.partition) end function translate(n::AsNode, ctx) @@ -453,7 +457,6 @@ function assemble(n::DefineNode, ctx) for (f, i) in n.label_map tr_cache[f] = translate(n.args[i], ctx, subs) end - repl = Dict{SQLQuery, Symbol}() trns = Pair{SQLQuery, SQLSyntax}[] for ref in ctx.refs if @dissect(ref, nothing |> Get(name = (local name))) && name in keys(tr_cache) @@ -606,7 +609,9 @@ function assemble(n::FromValuesNode, ctx) end function assemble(n::GroupNode, ctx) - has_aggregates = any(ref -> @dissect(ref, Agg() || Agg() |> Nested()), ctx.refs) + has_aggregates = + n.name === nothing && any(ref -> @dissect(ref, Agg()), ctx.refs) || + any(ref -> @dissect(ref, Nested(name = (local name))) && name == n.name, ctx.refs) if isempty(n.by) && !has_aggregates # NOOP: already processed in link() return assemble(nothing, ctx) end @@ -624,9 +629,9 @@ function assemble(n::GroupNode, ctx) if @dissect(ref, nothing |> Get(name = (local name))) @assert name in keys(n.label_map) push!(trns, ref => by[n.label_map[name]]) - elseif @dissect(ref, nothing |> Agg()) + elseif n.name === nothing && @dissect(ref, nothing |> Agg()) push!(trns, ref => translate(ref, ctx, subs)) - elseif @dissect(ref, (local tail = nothing |> Agg()) |> Nested()) + elseif @dissect(ref, (local tail) |> Nested(name = (local name))) && name == n.name push!(trns, ref => translate(tail, ctx, subs)) end end @@ -648,24 +653,30 @@ function assemble(n::GroupNode, ctx) end function assemble(n::IntoNode, ctx) - refs′ = SQLQuery[] - for ref in ctx.refs - if @dissect(ref, (local tail) |> Nested()) - push!(refs′, tail) - else - push!(refs′, ref) + base = assemble(ctx) + if all(@dissect(ref, (local tail) |> Nested()) && tail in keys(base.repl) for ref in ctx.refs) + repl′ = Dict{SQLQuery, Symbol}() + for ref in ctx.refs + @dissect(ref, (local tail) |> Nested()) || error() + repl′[ref] = base.repl[tail] end + return Assemblage(n.name, base.syntax, cols = base.cols, repl = repl′) + end + if !@dissect(base.syntax, SELECT() || UNION()) + base_alias = nothing + s = base.syntax + else + base_alias = allocate_alias(ctx, base) + s = FROM(AS(name = base_alias, tail = complete(base))) end - base = assemble(TranslateContext(ctx, refs = refs′)) - repl′ = Dict{SQLQuery, Symbol}() + subs = make_subs(base, base_alias) + trns = Pair{SQLQuery, SQLSyntax}[] for ref in ctx.refs - if @dissect(ref, (local tail) |> Nested()) - repl′[ref] = base.repl[tail] - else - repl′[ref] = base.repl[ref] - end + @dissect(ref, (local tail) |> Nested()) || error() + push!(trns, ref => translate(tail, ctx, subs)) end - Assemblage(n.name, base.syntax, cols = base.cols, repl = repl′) + repl, cols = make_repl_cols(trns) + Assemblage(base.name, s, cols = cols, repl = repl) end function assemble(n::IterateNode, ctx) @@ -839,14 +850,13 @@ function assemble(n::PartitionNode, ctx) partition = PARTITION(by = by, order_by = order_by, frame = n.frame) trns = Pair{SQLQuery, SQLSyntax}[] has_aggregates = false + ctx′′ = TranslateContext(ctx′, partition = partition) for ref in ctx.refs if @dissect(ref, nothing |> Agg()) && n.name === nothing - @dissect(translate(ref, ctx′), AGG(name = (local name), args = (local args), filter = (local filter))) || error() - push!(trns, ref => AGG(; name, args, filter, over = partition)) + push!(trns, ref => translate(ref, ctx′′)) has_aggregates = true - elseif @dissect(ref, (local tail = nothing |> Agg()) |> Nested(name = (local name))) && name === n.name - @dissect(translate(tail, ctx′), AGG(name = (local name), args = (local args), filter = (local filter))) || error() - push!(trns, ref => AGG(; name, args, filter, over = partition)) + elseif @dissect(ref, (local tail) |> Nested(name = (local name))) && name === n.name + push!(trns, ref => translate(tail, ctx′′)) has_aggregates = true else push!(trns, ref => subs[ref])