diff --git a/Project.toml b/Project.toml index 2d76c69e..bea4bf08 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "FunSQL" uuid = "cf6cc811-59f4-4a10-b258-a8547a8f6407" -authors = ["Kirill Simonov ", "Clark C. Evans "] version = "0.15.0" +authors = ["Kirill Simonov ", "Clark C. Evans "] [deps] DBInterface = "a10d1c49-ce27-4219-8d33-6db1a4562965" @@ -11,6 +11,7 @@ LRUCache = "8ac3fa9e-de4c-5943-b1dc-09c6b5f20637" OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" PrettyPrinting = "54e16d92-306c-5ea0-a30b-337be88ac337" Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" +URIs = "5c2747f8-b7ea-4ff2-ba2e-563bfd36b1d4" [compat] DBInterface = "2.5" @@ -19,4 +20,5 @@ LRUCache = "1.3" OrderedCollections = "1.4" PrettyPrinting = "0.3.2, 0.4" Tables = "1.6" +URIs = "1.6.1" julia = "1.10" diff --git a/src/FunSQL.jl b/src/FunSQL.jl index fb6a5ca8..1f99cf5a 100644 --- a/src/FunSQL.jl +++ b/src/FunSQL.jl @@ -56,6 +56,7 @@ export funsql_group, funsql_highlight, funsql_in, + funsql_into, funsql_iterate, funsql_is_not_null, funsql_is_null, @@ -94,6 +95,7 @@ using Tables using DBInterface using LRUCache using DataAPI +using URIs const SQLLiteralType = Union{Missing, Bool, Number, AbstractString, Dates.AbstractTime} diff --git a/src/catalogs.jl b/src/catalogs.jl index 0fbd446a..1fe7f2e6 100644 --- a/src/catalogs.jl +++ b/src/catalogs.jl @@ -31,22 +31,24 @@ _metadata_get(dict::SQLMetadata, key::Union{Symbol, AbstractString}, default; st end """ - SQLColumn(; name, metadata = nothing) - SQLColumn(name; metadata = nothing) + SQLColumn(; name, private = false, metadata = nothing) + SQLColumn(name; private = false, metadata = nothing) `SQLColumn` represents a column with the given `name` and optional `metadata`. +If `private` is `true`, the column is excluded from the default query output. """ struct SQLColumn name::Symbol + private::Bool metadata::SQLMetadata - function SQLColumn(; name::Union{Symbol, AbstractString}, metadata = nothing) - new(Symbol(name), _metadata(metadata)) + function SQLColumn(; name::Union{Symbol, AbstractString}, private = false, metadata = nothing) + new(Symbol(name), private, _metadata(metadata)) end end -SQLColumn(name; metadata = nothing) = - SQLColumn(name = name, metadata = metadata) +SQLColumn(name; private = false, metadata = nothing) = + SQLColumn(; name, private, metadata) Base.show(io::IO, col::SQLColumn) = print(io, quoteof(col, limit = true)) @@ -56,6 +58,9 @@ Base.show(io::IO, ::MIME"text/plain", col::SQLColumn) = function PrettyPrinting.quoteof(col::SQLColumn; limit::Bool = false) ex = Expr(:call, nameof(SQLColumn), QuoteNode(col.name)) + if col.private + push(ex.args, Expr(:kw, :private, col.private)) + end if !isempty(col.metadata) push!(ex.args, Expr(:kw, :metadata, limit ? :… : quoteof(reverse!(collect(col.metadata))))) end @@ -122,10 +127,10 @@ struct SQLTable <: AbstractDict{Symbol, SQLColumn} end SQLTable(name; qualifiers = Symbol[], columns, metadata = nothing) = - SQLTable(qualifiers = qualifiers, name = name, columns = columns, metadata = metadata) + SQLTable(; qualifiers, name, columns, metadata) SQLTable(name, columns...; qualifiers = Symbol[], metadata = nothing) = - SQLTable(qualifiers = qualifiers, name = name, columns = [columns...], metadata = metadata) + SQLTable(; qualifiers, name, columns = [columns...], metadata) _column_map(columns::OrderedDict{Symbol, SQLColumn}) = columns @@ -280,7 +285,7 @@ struct SQLCatalog <: AbstractDict{Symbol, SQLTable} end SQLCatalog(tables...; dialect = :default, cache = default_cache_maxsize, metadata = nothing) = - SQLCatalog(tables = tables, dialect = dialect, cache = cache, metadata = metadata) + SQLCatalog(; tables, dialect, cache, metadata) _table_map(tables::Dict{Symbol, SQLTable}) = tables diff --git a/src/clauses/internal.jl b/src/clauses/internal.jl index 8eea0184..6f069dbe 100644 --- a/src/clauses/internal.jl +++ b/src/clauses/internal.jl @@ -4,10 +4,10 @@ struct WithContextClause <: AbstractSQLClause dialect::SQLDialect - columns::Union{Vector{SQLColumn}, Nothing} + table::Union{SQLTable, Nothing} - WithContextClause(; dialect, columns = nothing) = - new(dialect, columns) + WithContextClause(; dialect, table = nothing) = + new(dialect, table) end const WITH_CONTEXT = SQLSyntaxCtor{WithContextClause}(:WITH_CONTEXT) @@ -17,8 +17,8 @@ function PrettyPrinting.quoteof(c::WithContextClause, ctx::QuoteContext) if c.dialect !== default_dialect push!(ex.args, Expr(:kw, :dialect, quoteof(c.dialect))) end - if c.columns !== nothing - push!(ex.args, Expr(:kw, :columns, Expr(:vect, Any[quoteof(col) for col in c.columns]...))) + if c.table !== nothing + push!(ex.args, Expr(:kw, :table, quoteof(c.table))) end ex end diff --git a/src/link.jl b/src/link.jl index 00015f41..85aed317 100644 --- a/src/link.jl +++ b/src/link.jl @@ -25,18 +25,44 @@ struct LinkContext knot_refs) end +function _select(t::RowType) + args = Pair{SQLColumn, SQLQuery}[] + for (f, ft) in t.fields + !(f in t.private_fields) || continue + lbl = escapefield(f) + if ft isa ScalarType + push!(args, SQLColumn(lbl) => Get(f)) + else + nested_args = _select(ft) + for (nested_col, nested_ref) in nested_args + nested_lbl = Symbol(lbl, '.', nested_col.name) + nested_col = SQLColumn(nested_lbl) + push!(args, nested_col => Nested(name = f, tail = nested_ref)) + end + end + end + args +end + +escapefield(f::Symbol) = + Symbol(escapefield(string(f))) + +escapefield(f) = + escapeuri(f, c -> c == ' ' || c != '.' && URIs.issafe(c)) + function link(q::SQLQuery) @dissect(q, (local tail) |> WithContext(catalog = (local catalog))) || throw(IllFormedError()) ctx = LinkContext(catalog) t = row_type(tail) - refs = SQLQuery[] - for (f, ft) in t.fields - if ft isa ScalarType - push!(refs, Get(f)) - end + col_refs = _select(t) + columns = first.(col_refs) + if isempty(columns) + columns = [SQLColumn(:_)] end + table = SQLTable(escapefield(label(q)), columns = columns) + refs = last.(col_refs) tail′ = Linked(refs, tail = link(dismantle(tail, ctx), ctx, refs)) - WithContext(tail = tail′, catalog = catalog, defs = ctx.defs) + WithContext(tail = tail′, catalog = catalog, table = table, defs = ctx.defs) end function dismantle(q::SQLQuery, ctx) @@ -123,19 +149,15 @@ function dismantle(n::GroupNode, ctx) Group(by = by′, sets = n.sets, name = n.name, label_map = n.label_map, tail = tail′) end -function dismantle(n::IterateNode, ctx) +function dismantle(n::IntoNode, ctx) tail′ = dismantle(ctx) - iterator′ = dismantle(n.iterator, ctx) - Iterate(iterator = iterator′, tail = tail′) + Into(name = n.name, tail = tail′) end -function dismantle(n::JoinNode, ctx) - rt = row_type(n.joinee) - router = JoinRouter(Set(keys(rt.fields)), !isa(rt.group, EmptyType)) +function dismantle(n::IterateNode, ctx) tail′ = dismantle(ctx) - joinee′ = dismantle(n.joinee, ctx) - on′ = dismantle_scalar(n.on, ctx) - RoutedJoin(joinee = joinee′, on = on′, router = router, left = n.left, right = n.right, optional = n.optional, tail = tail′) + iterator′ = dismantle(n.iterator, ctx) + Iterate(iterator = iterator′, tail = tail′) end function dismantle(n::LimitNode, ctx) @@ -181,6 +203,13 @@ function dismantle_scalar(n::ResolvedNode, ctx) end end +function dismantle(n::RoutedJoinNode, ctx) + tail′ = dismantle(ctx) + joinee′ = dismantle(n.joinee, ctx) + on′ = dismantle_scalar(n.on, ctx) + RoutedJoin(joinee = joinee′, on = on′, name = n.name, left = n.left, right = n.right, optional = n.optional, tail = tail′) +end + function dismantle(n::SelectNode, ctx) tail′ = dismantle(ctx) args′ = dismantle_scalar(n.args, ctx) @@ -232,16 +261,7 @@ function link(n::AppendNode, ctx) end function link(n::AsNode, ctx) - refs = SQLQuery[] - for ref in ctx.refs - if @dissect(ref, (local tail) |> Nested(name = (local name))) - @assert name == n.name - push!(refs, tail) - else - error() - end - end - tail′ = link(ctx.tail, ctx, refs) + tail′ = link(ctx) As(name = n.name, tail = tail′) end @@ -289,17 +309,28 @@ function link(n::FromIterateNode, ctx) end function link(n::FromTableExpressionNode, ctx) - refs = ctx.cte_refs[(n.name, n.depth)] - for ref in ctx.refs - push!(refs, Nested(name = n.name, tail = ref)) - end + cte_refs = ctx.cte_refs[(n.name, n.depth)] + append!(cte_refs, ctx.refs) n end function link(n::GroupNode, ctx) - has_aggregates = any(ref -> @dissect(ref, Agg() || Agg() |> Nested()), ctx.refs) - if !has_aggregates && isempty(n.by) - return link(FromNothing(), ctx) + krefs = SQLQuery[] + arefs = SQLQuery[] + has_aggregates = false + for ref in ctx.refs + if @dissect(ref, (local tail) |> Nested(name = (local name))) && name === n.name + gather!(tail, ctx, arefs) + has_aggregates = true + elseif @dissect(ref, Agg()) && n.name === nothing + push!(arefs, ref) + has_aggregates = true + else + push!(krefs, ref) + end + end + if isempty(arefs) && isempty(n.by) + return link(n.name !== nothing ? Into(n.name, tail = FromNothing()) : FromNothing(), ctx) end # Some group keys are added both to SELECT and to GROUP BY. # To avoid duplicate SQL, they must be evaluated in a nested subquery. @@ -312,16 +343,16 @@ function link(n::GroupNode, ctx) # Ignore `SELECT DISTINCT` case. if has_aggregates ctx′ = LinkContext(ctx, refs = refs) - for ref in ctx.refs - if (@dissect(ref, nothing |> Agg(args = (local args), filter = (local filter)) |> Nested(name = (local name))) && name === n.name) || - (@dissect(ref, nothing |> Agg(args = (local args), filter = (local filter))) && n.name === nothing) - gather!(args, ctx′) - if filter !== nothing - gather!(filter, ctx′) - end - elseif @dissect(ref, nothing |> Get(name = (local name))) && name in keys(n.label_map) - # Force evaluation in a nested subquery. - push!(refs, n.by[n.label_map[name]]) + for ref in krefs + @dissect(ref, Get(name = (local name))) && name in keys(n.label_map) || error() + # Force evaluation in a nested subquery. + push!(refs, n.by[n.label_map[name]]) + end + for ref in arefs + @dissect(ref, Agg(args = (local args), filter = (local filter))) || error() + gather!(args, ctx′) + if filter !== nothing + gather!(filter, ctx′) end end end @@ -333,6 +364,20 @@ function link(n::GroupNode, ctx) Group(by = n.by, sets = n.sets, name = n.name, label_map = n.label_map, tail = tail′) end +function link(n::IntoNode, ctx) + refs = SQLQuery[] + for ref in ctx.refs + if @dissect(ref, (local tail) |> Nested(name = (local name))) + @assert name == n.name + gather!(tail, ctx, refs) + else + error() + end + end + tail′ = Linked(refs, 0, tail = link(ctx.tail, ctx, refs)) + Into(name = n.name, tail = tail′) +end + function link(n::IterateNode, ctx) iterator′ = n.iterator defs = copy(ctx.defs) @@ -364,53 +409,6 @@ function link(n::IterateNode, ctx) Padding(tail = q′) end -function route(r::JoinRouter, ref::SQLQuery) - if @dissect(ref, Nested(name = (local name))) && name in r.label_set - return 1 - end - if @dissect(ref, Get(name = (local name))) && name in r.label_set - return 1 - end - if @dissect(ref, Agg()) && r.group - return 1 - end - return -1 -end - -function link(n::RoutedJoinNode, ctx) - lrefs = SQLQuery[] - rrefs = SQLQuery[] - for ref in ctx.refs - turn = route(n.router, ref) - push!(turn < 0 ? lrefs : rrefs, ref) - end - if n.optional && isempty(rrefs) - return link(ctx) - end - ln_ext_refs = length(lrefs) - rn_ext_refs = length(rrefs) - refs′ = SQLQuery[] - lateral_refs = SQLQuery[] - gather!(n.joinee, ctx, lateral_refs) - append!(lrefs, lateral_refs) - lateral = !isempty(lateral_refs) - gather!(n.on, ctx, refs′) - for ref in refs′ - turn = route(n.router, ref) - push!(turn < 0 ? lrefs : rrefs, ref) - end - tail′ = Linked(lrefs, ln_ext_refs, tail = link(ctx.tail, ctx, lrefs)) - joinee′ = Linked(rrefs, rn_ext_refs, tail = link(n.joinee, ctx, rrefs)) - RoutedJoin( - joinee = joinee′, - on = n.on, - router = n.router, - left = n.left, - right = n.right, - lateral = lateral, - tail = tail′) -end - function link(n::LimitNode, ctx) tail′ = Linked(ctx.refs, tail = link(ctx)) Limit(offset = n.offset, limit = n.limit, tail = tail′) @@ -433,16 +431,14 @@ end function link(n::PartitionNode, ctx) refs = SQLQuery[] - imm_refs = SQLQuery[] - ctx′ = LinkContext(ctx, refs = imm_refs) + arefs = SQLQuery[] has_aggregates = false for ref in ctx.refs - if (@dissect(ref, nothing |> Agg(args = (local args), filter = (local filter)) |> Nested(name = (local name))) && name === n.name) || - (@dissect(ref, nothing |> Agg(args = (local args), filter = (local filter))) && n.name === nothing) - gather!(args, ctx′) - if filter !== nothing - gather!(filter, ctx′) - end + if @dissect(ref, (local tail) |> Nested(name = (local name))) && name === n.name + gather!(tail, ctx, arefs) + has_aggregates = true + elseif @dissect(ref, Agg()) && n.name === nothing + push!(arefs, ref) has_aggregates = true else push!(refs, ref) @@ -451,6 +447,15 @@ function link(n::PartitionNode, ctx) if !has_aggregates return link(ctx) end + imm_refs = SQLQuery[] + ctx′ = LinkContext(ctx, refs = imm_refs) + for ref in arefs + @dissect(ref, Agg(args = (local args), filter = (local filter))) || error() + gather!(args, ctx′) + if filter !== nothing + gather!(filter, ctx′) + end + end gather!(n.by, ctx′) gather!(n.order_by, ctx′) n_ext_refs = length(refs) @@ -459,6 +464,46 @@ function link(n::PartitionNode, ctx) Partition(by = n.by, order_by = n.order_by, frame = n.frame, name = n.name, tail = tail′) end +function link(n::RoutedJoinNode, ctx) + lrefs = SQLQuery[] + rrefs = SQLQuery[] + for ref in ctx.refs + if @dissect(ref, Nested(name = (local name))) && name === n.name + push!(rrefs, ref) + else + push!(lrefs, ref) + end + end + if n.optional && isempty(rrefs) + return link(ctx) + end + ln_ext_refs = length(lrefs) + rn_ext_refs = length(rrefs) + refs′ = SQLQuery[] + lateral_refs = SQLQuery[] + gather!(n.joinee, ctx, lateral_refs) + append!(lrefs, lateral_refs) + lateral = !isempty(lateral_refs) + gather!(n.on, ctx, refs′) + for ref in refs′ + if @dissect(ref, Nested(name = (local name))) && name === n.name + push!(rrefs, ref) + else + push!(lrefs, ref) + end + end + tail′ = Linked(lrefs, ln_ext_refs, tail = link(ctx.tail, ctx, lrefs)) + joinee′ = Linked(rrefs, rn_ext_refs, tail = link(Into(name = n.name, tail = n.joinee), ctx, rrefs)) + RoutedJoin( + joinee = joinee′, + on = n.on, + name = n.name, + left = n.left, + right = n.right, + lateral = lateral, + tail = tail′) +end + function link(n::SelectNode, ctx) refs = SQLQuery[] gather!(n.args, ctx, refs) @@ -556,12 +601,9 @@ end function gather!(n::IsolatedNode, ctx) def = ctx.defs[n.idx] !@dissect(def, Linked()) || return - refs = SQLQuery[] - for (f, ft) in n.type.fields - if ft isa ScalarType - push!(refs, Get(f)) - break - end + refs = last.(_select(n.type)) + if !isempty(refs) + refs = refs[1:1] end def′ = Linked(refs, tail = link(def, ctx, refs)) ctx.defs[n.idx] = def′ diff --git a/src/nodes.jl b/src/nodes.jl index ab470726..b8e68a13 100644 --- a/src/nodes.jl +++ b/src/nodes.jl @@ -54,17 +54,19 @@ terminal(q::SQLQuery) = Chain(q′, q) = convert(SQLQuery, q)(q′) -label(q::SQLQuery) = - @something label(q.head) label(q.tail) +function label(q::SQLQuery; default = :_) + l = label(q.head) + l !== nothing ? l : label(q.tail; default) +end label(n::AbstractSQLNode) = nothing -label(::Nothing) = - :_ +label(::Nothing; default = :_) = + default -label(q) = - label(convert(SQLQuery, q)) +label(q; default = :_) = + label(convert(SQLQuery, q); default) # A variant of SQLQuery for assembling a chain of identifiers. @@ -913,6 +915,7 @@ include("nodes/get.jl") include("nodes/group.jl") include("nodes/highlight.jl") include("nodes/internal.jl") +include("nodes/into.jl") include("nodes/iterate.jl") include("nodes/join.jl") include("nodes/limit.jl") diff --git a/src/nodes/as.jl b/src/nodes/as.jl index 3a4a5db7..d1a61a2d 100644 --- a/src/nodes/as.jl +++ b/src/nodes/as.jl @@ -16,8 +16,7 @@ AsNode(name) = As(name; tail = nothing) name => tail -In a scalar context, `As` specifies the name of the output column. When -applied to tabular data, `As` wraps the data in a nested record. +`As` specifies the name of the output column. The arrow operator (`=>`) is a shorthand notation for `As`. @@ -35,19 +34,19 @@ SELECT "person_1"."person_id" AS "id" FROM "person" AS "person_1" ``` -*Show all patients together with their state of residence.* +*Show all patients together with their primary care provider.* ```jldoctest -julia> person = SQLTable(:person, columns = [:person_id, :year_of_birth, :location_id]); +julia> person = SQLTable(:person, columns = [:person_id, :year_of_birth, :provider_id]); -julia> location = SQLTable(:location, columns = [:location_id, :state]); +julia> provider = SQLTable(:provider, columns = [:provider_id, :provider_name]); julia> q = From(:person) |> - Join(From(:location) |> As(:location), - on = Get.location_id .== Get.location.location_id) |> - Select(Get.person_id, Get.location.state); + Join(From(:provider) |> As(:pcp), + on = Get.provider_id .== Get.pcp.provider_id) |> + Select(Get.person_id, Get.pcp.provider_name); -julia> print(render(q, tables = [person, location])) +julia> print(render(q, tables = [person, provider])) SELECT "person_1"."person_id", "location_1"."state" diff --git a/src/nodes/define.jl b/src/nodes/define.jl index fab058de..3fd44dfc 100644 --- a/src/nodes/define.jl +++ b/src/nodes/define.jl @@ -4,13 +4,14 @@ struct DefineNode <: TabularNode args::Vector{SQLQuery} before::Union{Symbol, Bool} after::Union{Symbol, Bool} + private::Bool label_map::OrderedDict{Symbol, Int} - function DefineNode(; args = [], before = nothing, after = nothing, label_map = nothing) + function DefineNode(; args = [], before = nothing, after = nothing, private = false, label_map = nothing) if label_map !== nothing - n = new(args, something(before, false), something(after, false), label_map) + n = new(args, something(before, false), something(after, false), private, label_map) else - n = new(args, something(before, false), something(after, false), OrderedDict{Symbol, Int}()) + n = new(args, something(before, false), something(after, false), private, OrderedDict{Symbol, Int}()) populate_label_map!(n) end if (n.before isa Symbol || n.before) && (n.after isa Symbol || n.after) @@ -20,12 +21,12 @@ struct DefineNode <: TabularNode end end -DefineNode(args...; before = nothing, after = nothing) = - DefineNode(args = SQLQuery[args...], before = before, after = after) +DefineNode(args...; before = nothing, after = nothing, private = false) = + DefineNode(args = SQLQuery[args...], before = before, after = after, private = private) """ - Define(; args = [], before = nothing, after = nothing, tail = nothing) - Define(args...; before = nothing, after = nothing, tail = nothing) + Define(; args = [], before = nothing, after = nothing, private = false, tail = nothing) + Define(args...; before = nothing, after = nothing, private = false, tail = nothing) The `Define` node adds or replaces output columns. @@ -35,6 +36,8 @@ both new and replaced columns at the end (after a specified column). Alternatively, set `before = true` (`before = `) to add both new and replaced columns at the front (before the specified column). +If `private` is set, the columns will be excluded from the query output. + # Examples *Show patients who are at least 16 years old.* @@ -90,5 +93,8 @@ function PrettyPrinting.quoteof(n::DefineNode, ctx::QuoteContext) if n.after !== false push!(ex.args, Expr(:kw, :after, n.after isa Symbol ? QuoteNode(n.after) : n.after)) end + if n.private !== false + push!(ex.args, Expr(:kw, :private, n.private)) + end ex end diff --git a/src/nodes/function.jl b/src/nodes/function.jl index aae32314..cc60dbf4 100644 --- a/src/nodes/function.jl +++ b/src/nodes/function.jl @@ -160,9 +160,6 @@ const funsql_fun = Fun transliterate(::typeof(Fun), name::Symbol, ctx::TransliterateContext, @nospecialize(args...)) = Fun(name, args = [transliterate(SQLQuery, arg, ctx) for arg in args]) -terminal(::Type{FunctionNode}) = - true - PrettyPrinting.quoteof(n::FunctionNode, ctx::QuoteContext) = Expr(:call, Expr(:., :Fun, diff --git a/src/nodes/internal.jl b/src/nodes/internal.jl index 91f866fd..5ec0bd44 100644 --- a/src/nodes/internal.jl +++ b/src/nodes/internal.jl @@ -4,9 +4,10 @@ struct WithContextNode <: AbstractSQLNode catalog::SQLCatalog defs::Vector{SQLQuery} + table::Union{SQLTable, Nothing} - WithContextNode(; catalog = SQLCatalog(), defs = SQLQuery[]) = - new(catalog, defs) + WithContextNode(; catalog = SQLCatalog(), defs = SQLQuery[], table = nothing) = + new(catalog, defs, table) end const WithContext = SQLQueryCtor{WithContextNode}(:WithContext) @@ -17,6 +18,9 @@ function PrettyPrinting.quoteof(n::WithContextNode, ctx::QuoteContext) if !isempty(n.defs) push!(ex.args, Expr(:kw, :defs, Expr(:vect, Any[quoteof(def, ctx) for def in n.defs]...))) end + if n.table !== nothing + push!(ex.args, Expr(:kw, :table, quoteof(n.table))) + end ex end @@ -203,29 +207,21 @@ PrettyPrinting.quoteof(n::FromFunctionNode, ctx::QuoteContext) = # Annotated Join node. -struct JoinRouter - label_set::Set{Symbol} - group::Bool -end - -PrettyPrinting.quoteof(r::JoinRouter) = - Expr(:call, :JoinRouter, quoteof(r.label_set), quoteof(r.group)) - struct RoutedJoinNode <: TabularNode joinee::SQLQuery on::SQLQuery - router::JoinRouter + name::Symbol left::Bool right::Bool lateral::Bool optional::Bool - RoutedJoinNode(; joinee, on, router, left, right, lateral = false, optional = false) = - new(joinee, on, router, left, right, lateral, optional) + RoutedJoinNode(; joinee, on, name = label(joinee), left, right, lateral = false, optional = false) = + new(joinee, on, name, left, right, lateral, optional) end -RoutedJoinNode(joinee, on; router, left = false, right = false, lateral = false, optional = false) = - RoutedJoinNode(joinee = joinee, on = on, router, left = left, right = right, lateral = lateral, optional = optional) +RoutedJoinNode(joinee, on; name = label(joinee), left = false, right = false, lateral = false, optional = false) = + RoutedJoinNode(name = name, on = on, router, left = left, right = right, lateral = lateral, optional = optional) const RoutedJoin = SQLQueryCtor{RoutedJoinNode}(:RoutedJoin) @@ -234,7 +230,7 @@ function PrettyPrinting.quoteof(n::RoutedJoinNode, ctx::QuoteContext) if !ctx.limit push!(ex.args, quoteof(n.joinee, ctx)) push!(ex.args, quoteof(n.on, ctx)) - push!(ex.args, Expr(:kw, :router, quoteof(n.router))) + push!(ex.args, Expr(:kw, :name, QuoteNode(n.name))) if n.left push!(ex.args, Expr(:kw, :left, n.left)) end @@ -310,7 +306,7 @@ PrettyPrinting.quoteof(n::FunSQLMacroNode, ctx::QuoteContext) = Expr(:macrocall, Symbol("@funsql"), n.line, !ctx.limit ? n.ex : :…) label(n::FunSQLMacroNode) = - label(n.query) + label(n.query, default = nothing) # Unwrap @funsql macro when displaying the query. diff --git a/src/nodes/into.jl b/src/nodes/into.jl new file mode 100644 index 00000000..655652c0 --- /dev/null +++ b/src/nodes/into.jl @@ -0,0 +1,33 @@ +# Wrap the output into a nested record. + +mutable struct IntoNode <: TabularNode + name::Symbol + private::Bool + + IntoNode(; name::Union{Symbol, AbstractString}, private::Bool = false) = + new(Symbol(name), private) +end + +IntoNode(name; private = false) = + IntoNode(; name, private) + +""" + Into(; name, private = false, tail = nothing) + Into(name; private = false, tail = nothing) + +`Into` wraps output columns in a nested record. +""" +const Into = SQLQueryCtor{IntoNode}(:Into) + +const funsql_into = Into + +function PrettyPrinting.quoteof(n::IntoNode, ctx::QuoteContext) + ex = Expr(:call, :Into, quoteof(n.name)) + if n.private + push!(ex.args, Expr(:kw, :private, n.private)) + end + ex +end + +label(n::IntoNode) = + n.name diff --git a/src/nodes/join.jl b/src/nodes/join.jl index c1bc2eb4..027914f3 100644 --- a/src/nodes/join.jl +++ b/src/nodes/join.jl @@ -6,21 +6,23 @@ mutable struct JoinNode <: TabularNode left::Bool right::Bool optional::Bool + swap::Bool + private::Bool - JoinNode(; joinee, on, left = false, right = false, optional = false) = - new(joinee, on, left, right, optional) + JoinNode(; joinee, on, left = false, right = false, optional = false, swap = false, private = false) = + new(joinee, on, left, right, optional, swap, private) end -JoinNode(joinee; on, left = false, right = false, optional = false) = - JoinNode(; joinee, on, left, right, optional) +JoinNode(joinee; on, left = false, right = false, optional = false, swap = false, private = false) = + JoinNode(; joinee, on, left, right, optional, swap, private) -JoinNode(joinee, on; left = false, right = false, optional = false) = - JoinNode(; joinee, on, left, right, optional) +JoinNode(joinee, on; left = false, right = false, optional = false, swap = false, private = false) = + JoinNode(; joinee, on, left, right, optional, swap, private) """ - Join(; joinee, on, left = false, right = false, optional = false) - Join(joinee; on, left = false, right = false, optional = false) - Join(joinee, on; left = false, right = false, optional = false) + Join(; joinee, on, left = false, right = false, optional = false, swap = false, private = false) + Join(joinee; on, left = false, right = false, optional = false, swap = false, private = false) + Join(joinee, on; left = false, right = false, optional = false, swap = false, private = false) `Join` correlates two input datasets. @@ -102,8 +104,17 @@ function PrettyPrinting.quoteof(n::JoinNode, ctx::QuoteContext) if n.optional push!(ex.args, Expr(:kw, :optional, n.optional)) end + if n.swap + push!(ex.args, Expr(:kw, :swap, n.swap)) + end + if n.private + push!(ex.args, Expr(:kw, :private, n.private)) + end else push!(ex.args, :…) end ex end + +label(n::JoinNode) = + n.swap ? label(n.joinee) : nothing diff --git a/src/nodes/literal.jl b/src/nodes/literal.jl index f4ff534e..593be93d 100644 --- a/src/nodes/literal.jl +++ b/src/nodes/literal.jl @@ -45,8 +45,5 @@ Base.convert(::Type{SQLQuery}, val::SQLLiteralType) = Base.convert(::Type{SQLQuery}, ref::Base.RefValue) = Lit(ref.x) -terminal(::Type{LiteralNode}) = - true - PrettyPrinting.quoteof(n::LiteralNode, ctx::QuoteContext) = Expr(:call, :Lit, n.val) diff --git a/src/resolve.jl b/src/resolve.jl index 73bf67ac..5608f7c1 100644 --- a/src/resolve.jl +++ b/src/resolve.jl @@ -178,9 +178,8 @@ end function resolve(n::AsNode, ctx) tail′ = resolve(ctx) - t = row_type(tail′) q′ = As(name = n.name, tail = tail′) - Resolved(RowType(FieldTypeMap(n.name => t)), tail = q′) + Resolved(type(tail′), tail = q′) end function resolve_scalar(n::AsNode, ctx) @@ -264,16 +263,28 @@ function resolve(n::DefineNode, ctx) end end end + private_fields = copy(t.private_fields) + for l in keys(n.label_map) + if n.private + push!(private_fields, l) + else + delete!(private_fields, l) + end + end q′ = Define(args = args′, label_map = n.label_map, tail = tail′) - Resolved(RowType(fields, t.group), tail = q′) + Resolved(RowType(fields, t.group, private_fields), tail = q′) end function RowType(table::SQLTable) fields = FieldTypeMap() - for f in keys(table.columns) + private_fields = Set{Symbol}() + for (f, c) in table.columns fields[f] = ScalarType() + if c.private + push!(private_fields, f) + end end - RowType(fields) + RowType(fields, EmptyType(), private_fields) end function resolve(n::FromNode, ctx) @@ -331,6 +342,10 @@ function resolve(n::FromNode, ctx) end function resolve_scalar(n::FunctionNode, ctx) + if ctx.tail !== nothing + q′ = unnest(ctx.tail, convert(SQLQuery, n), ctx) + return resolve_scalar(q′, ctx) + end args′ = resolve_scalar(n.args, ctx) q′ = Fun(name = n.name, args = args′) Resolved(ScalarType(), tail = q′) @@ -356,7 +371,7 @@ resolve_scalar(n::FunSQLMacroNode, ctx) = function resolve(n::GetNode, ctx) if ctx.tail !== nothing - q′ = unnest(ctx.tail, Get(n.name), ctx) + q′ = unnest(ctx.tail, convert(SQLQuery, n), ctx) return resolve(q′, ctx) end resolve(FromNode(n.name), ctx) @@ -364,7 +379,7 @@ end function resolve_scalar(n::GetNode, ctx) if ctx.tail !== nothing - q′ = unnest(ctx.tail, Get(n.name), ctx) + q′ = unnest(ctx.tail, convert(SQLQuery, n), ctx) return resolve_scalar(q′, ctx) end t = get(ctx.row_type.fields, n.name, EmptyType()) @@ -391,8 +406,12 @@ function resolve(n::GroupNode, ctx) fields[n.name] = RowType(FieldTypeMap(), group) group = EmptyType() end - q′ = Group(by = by′, sets = n.sets, label_map = n.label_map, tail = tail′) - Resolved(RowType(fields, group), tail = q′) + private_fields = Set{Symbol}() + if n.name !== nothing + push!(private_fields, n.name) + end + q′ = Group(by = by′, sets = n.sets, name = n.name, label_map = n.label_map, tail = tail′) + Resolved(RowType(fields, group, private_fields), tail = q′) end resolve(::HighlightNode, ctx) = @@ -401,6 +420,14 @@ resolve(::HighlightNode, ctx) = resolve_scalar(::HighlightNode, ctx) = resolve_scalar(ctx) +function resolve(n::IntoNode, ctx) + tail′ = resolve(ctx) + t = row_type(tail′) + q′ = Into(name = n.name, tail = tail′) + t′ = RowType(FieldTypeMap(n.name => t), EmptyType(), n.private ? Set([n.name]) : Set{Symbol}()) + Resolved(t′, tail = q′) +end + function resolve(n::IterateNode, ctx) tail′ = resolve(ResolveContext(ctx, knot_type = nothing, implicit_knot = false)) t = row_type(tail′) @@ -416,23 +443,30 @@ function resolve(n::IterateNode, ctx) end function resolve(n::JoinNode, ctx) + if n.swap + ctx′ = ResolveContext(ctx, tail = n.joinee) + return resolve(JoinNode(joinee = ctx.tail, on = n.on, left = n.right, right = n.left, optional = n.optional, private = n.private), ctx′) + end tail′ = resolve(ctx) lt = row_type(tail′) + name = label(n.joinee) joinee′ = resolve(n.joinee, ResolveContext(ctx, row_type = lt, implicit_knot = false)) rt = row_type(joinee′) fields = FieldTypeMap() for (f, ft) in lt.fields - fields[f] = get(rt.fields, f, ft) + fields[f] = ft end - for (f, ft) in rt.fields - if !haskey(fields, f) - fields[f] = ft - end + fields[name] = rt + group = lt.group + private_fields = copy(lt.private_fields) + if n.private + push!(private_fields, name) + else + delete!(private_fields, name) end - group = rt.group isa EmptyType ? lt.group : rt.group - t = RowType(fields, group) + t = RowType(fields, group, private_fields) on′ = resolve_scalar(n.on, ctx, t) - q′ = Join(joinee = joinee′, on = on′, left = n.left, right = n.right, optional = n.optional, tail = tail′) + q′ = RoutedJoin(joinee = joinee′, on = on′, name = name, left = n.left, right = n.right, optional = n.optional, tail = tail′) Resolved(t, tail = q′) end @@ -447,6 +481,10 @@ function resolve(n::LimitNode, ctx) end function resolve_scalar(n::LiteralNode, ctx) + if ctx.tail !== nothing + q′ = unnest(ctx.tail, convert(SQLQuery, n), ctx) + return resolve_scalar(q′, ctx) + end Resolved(ScalarType(), tail = convert(SQLQuery, n)) end @@ -483,8 +521,12 @@ function resolve(n::PartitionNode, ctx) end fields[n.name] = RowType(FieldTypeMap(), t) end + private_fields = copy(t.private_fields) + if n.name !== nothing + push!(private_fields, n.name) + end q′ = Partition(by = by′, order_by = order_by′, frame = n.frame, name = n.name, tail = tail′) - Resolved(RowType(fields, group), tail = q′) + Resolved(RowType(fields, group, private_fields), tail = q′) end function resolve(n::SelectNode, ctx) @@ -532,16 +574,7 @@ function resolve(n::Union{WithNode, WithExternalNode}, ctx) v = get(ctx.cte_types, name, nothing) depth = 1 + (v !== nothing ? v[1] : 0) t = row_type(args′[i]) - cte_t = get(t.fields, name, EmptyType()) - if !(cte_t isa RowType) - throw( - ReferenceError( - REFERENCE_ERROR_TYPE.INVALID_TABLE_REFERENCE, - name = name, - path = get_path(ctx))) - - end - cte_types′ = Base.ImmutableDict(cte_types′, name => (depth, cte_t)) + cte_types′ = Base.ImmutableDict(cte_types′, name => (depth, t)) end ctx′ = ResolveContext(ctx, cte_types = cte_types′) tail′ = resolve(ctx′) diff --git a/src/serialize.jl b/src/serialize.jl index 16d8a8b5..02deb302 100644 --- a/src/serialize.jl +++ b/src/serialize.jl @@ -13,11 +13,11 @@ mutable struct SerializeContext <: IO end function serialize(s::SQLSyntax) - @dissect(s, WITH_CONTEXT(tail = (local s′), dialect = (local dialect), columns = (local columns))) || throw(IllFormedError()) + @dissect(s, WITH_CONTEXT(tail = (local s′), dialect = (local dialect), table = (local table))) || throw(IllFormedError()) ctx = SerializeContext(dialect, s′) serialize!(ctx) raw = String(take!(ctx.io)) - SQLString(raw, columns = columns, vars = ctx.vars) + SQLString(raw, table = table, vars = ctx.vars) end Base.write(ctx::SerializeContext, octet::UInt8) = diff --git a/src/strings.jl b/src/strings.jl index 122a8848..c0d2cce6 100644 --- a/src/strings.jl +++ b/src/strings.jl @@ -5,7 +5,8 @@ Serialized SQL query. -Parameter `columns` is a vector describing the output columns. +Parameter `table` is an optional `SQLTable` object describing the output of +the query. Parameter `vars` is a vector of query parameters (created with [`Var`](@ref)) in the order they are expected by the `DBInterface.execute()` function. @@ -55,11 +56,11 @@ SQLString(\""" """ struct SQLString <: AbstractString raw::String - columns::Union{Vector{SQLColumn}, Nothing} + table::Union{SQLTable, Nothing} vars::Vector{Symbol} - SQLString(raw; columns = nothing, vars = Symbol[]) = - new(raw, columns, vars) + SQLString(raw; table = nothing, vars = Symbol[]) = + new(raw, table, vars) end Base.ncodeunits(sql::SQLString) = @@ -88,8 +89,8 @@ Base.write(io::IO, sql::SQLString) = function PrettyPrinting.quoteof(sql::SQLString) ex = Expr(:call, nameof(SQLString), sql.raw) - if sql.columns !== nothing - push!(ex.args, Expr(:kw, :columns, Expr(:vect, Any[quoteof(col) for col in sql.columns]...))) + if sql.table !== nothing + push!(ex.args, Expr(:kw, :table, quoteof(sql.table))) end if !isempty(sql.vars) push!(ex.args, Expr(:kw, :vars, quoteof(sql.vars))) @@ -100,10 +101,11 @@ end function Base.show(io::IO, sql::SQLString) print(io, "SQLString(") show(io, sql.raw) - if sql.columns !== nothing - print(io, ", columns = ") - l = length(sql.columns) - print(io, l == 0 ? "[]" : l == 1 ? "[…1 column…]" : "[…$l columns…]") + if sql.table !== nothing + print(io, ", table = SQLTable(") + show(io, sql.table.name) + l = length(sql.table.columns) + print(io, ", ", l == 0 ? "[]" : l == 1 ? "[…1 column…]" : "[…$l columns…]", ")") end if !isempty(sql.vars) print(io, ", vars = ") diff --git a/src/translate.jl b/src/translate.jl index 84b3bd79..a616b1ef 100644 --- a/src/translate.jl +++ b/src/translate.jl @@ -39,10 +39,13 @@ function complete(a::Assemblage) end # Add a SELECT clause aligned with the exported references. -function complete_aligned(a::Assemblage, ctx) +function complete_aligned(a::Assemblage, ctx, expected_names = nothing) + names = collect(keys(a.cols)) + expected_names = something(expected_names, names) aligned = length(a.cols) == length(ctx.refs) && - all(a.repl[ref] === name for (name, ref) in zip(keys(a.cols), ctx.refs)) + all(a.repl[ref] === name for (name, ref) in zip(names, ctx.refs)) && + names == expected_names !aligned || return complete(a) if !@dissect(a.syntax, SELECT() || UNION()) alias = nothing @@ -54,9 +57,9 @@ function complete_aligned(a::Assemblage, ctx) subs = make_subs(a, alias) repl = Dict{SQLQuery, Symbol}() cols = OrderedDict{Symbol, SQLSyntax}() - for ref in ctx.refs - name = repl[ref] = a.repl[ref] - cols[name] = subs[ref] + for (expected_name, ref) in zip(expected_names, ctx.refs) + cols[expected_name] = subs[ref] + repl[ref] = expected_name end a′ = Assemblage(a.name, syntax, repl = repl, cols = cols) complete(a′) @@ -172,6 +175,7 @@ struct TranslateContext refs::Vector{SQLQuery} vars::Base.ImmutableDict{Tuple{Symbol, Int}, SQLSyntax} subs::Dict{SQLQuery, SQLSyntax} + partition::Union{SQLSyntax, Nothing} TranslateContext(; catalog, defs) = new(catalog, @@ -184,9 +188,10 @@ struct TranslateContext 0, SQLQuery[], Base.ImmutableDict{Tuple{Symbol, Int}, SQLSyntax}(), - Dict{Int, SQLSyntax}()) + Dict{Int, SQLSyntax}(), + nothing) - function TranslateContext(ctx::TranslateContext; tail = ctx.tail, cte_map = ctx.cte_map, knot = ctx.knot, refs = ctx.refs, vars = ctx.vars, subs = ctx.subs) + function TranslateContext(ctx::TranslateContext; tail = ctx.tail, cte_map = ctx.cte_map, knot = ctx.knot, refs = ctx.refs, vars = ctx.vars, subs = ctx.subs, partition = ctx.partition) new(ctx.catalog, tail, ctx.defs, @@ -197,7 +202,8 @@ struct TranslateContext knot, refs, vars, - subs) + subs, + partition) end end @@ -211,15 +217,11 @@ function allocate_alias(ctx::TranslateContext, alias::Symbol) end function translate(q::SQLQuery) - @dissect(q, (local q′) |> Linked(refs = (local refs)) |> WithContext(catalog = (local catalog), defs = (local defs))) || throw(IllFormedError()) + @dissect(q, (local q′) |> Linked(refs = (local refs)) |> WithContext(catalog = (local catalog), table = (local table), defs = (local defs))) || throw(IllFormedError()) ctx = TranslateContext(catalog = catalog, defs = defs) ctx′ = TranslateContext(ctx, refs = refs) base = assemble(q′, ctx′) - columns = nothing - if !isempty(refs) - columns = [SQLColumn(base.repl[ref]) for ref in refs] - end - c = complete_aligned(base, ctx′) + c = complete_aligned(base, ctx′, collect(keys(table))) with_args = SQLSyntax[] for cte_a in ctx.ctes !cte_a.external || continue @@ -238,7 +240,7 @@ function translate(q::SQLQuery) if !isempty(with_args) c = WITH(tail = c, args = with_args, recursive = ctx.recursive[]) end - WITH_CONTEXT(tail = c, dialect = ctx.catalog.dialect, columns = columns) + WITH_CONTEXT(tail = c, dialect = ctx.catalog.dialect, table = table) end function translate(q::SQLQuery, ctx) @@ -266,9 +268,10 @@ function translate(q, ctx::TranslateContext, subs::Dict{SQLQuery, SQLSyntax}) end function translate(n::AggregateNode, ctx) - args = translate(n.args, ctx) - filter = translate(n.filter, ctx) - AGG(n.name, args = args, filter = filter) + ctx′ = ctx.partition !== nothing ? TranslateContext(ctx, partition = nothing) : ctx + args = translate(n.args, ctx′) + filter = translate(n.filter, ctx′) + AGG(n.name, args = args, filter = filter, over = ctx.partition) end function translate(n::AsNode, ctx) @@ -427,26 +430,8 @@ function assemble(n::AppendNode, ctx) Assemblage(a_name, s, repl = repl, cols = dummy_cols) end -function assemble(n::AsNode, ctx) - refs′ = SQLQuery[] - for ref in ctx.refs - if @dissect(ref, (local tail) |> Nested()) - push!(refs′, tail) - else - push!(refs′, ref) - end - end - base = assemble(TranslateContext(ctx, refs = refs′)) - repl′ = Dict{SQLQuery, Symbol}() - for ref in ctx.refs - if @dissect(ref, (local tail) |> Nested()) - repl′[ref] = base.repl[tail] - else - repl′[ref] = base.repl[ref] - end - end - Assemblage(n.name, base.syntax, cols = base.cols, repl = repl′) -end +assemble(n::AsNode, ctx) = + assemble(ctx) function assemble(n::BindNode, ctx) vars′ = ctx.vars @@ -472,7 +457,6 @@ function assemble(n::DefineNode, ctx) for (f, i) in n.label_map tr_cache[f] = translate(n.args[i], ctx, subs) end - repl = Dict{SQLQuery, Symbol}() trns = Pair{SQLQuery, SQLSyntax}[] for ref in ctx.refs if @dissect(ref, nothing |> Get(name = (local name))) && name in keys(tr_cache) @@ -530,21 +514,12 @@ end assemble(::FromNothingNode, ctx) = assemble(nothing, ctx) -function unwrap_repl(a::Assemblage) - repl′ = Dict{SQLQuery, Symbol}() - for (ref, name) in a.repl - @dissect(ref, (local tail) |> Nested()) || error() - repl′[tail] = name - end - Assemblage(a.name, a.syntax, cols = a.cols, repl = repl′) -end - function assemble(n::FromTableExpressionNode, ctx) cte_a = ctx.ctes[ctx.cte_map[(n.name, n.depth)]] alias = allocate_alias(ctx, n.name) tbl = convert(SQLSyntax, (cte_a.qualifiers, cte_a.name)) s = FROM(AS(name = alias, tail = tbl)) - subs = make_subs(unwrap_repl(cte_a.a), alias) + subs = make_subs(cte_a.a, alias) trns = Pair{SQLQuery, SQLSyntax}[] for ref in ctx.refs push!(trns, ref => subs[ref]) @@ -634,7 +609,9 @@ function assemble(n::FromValuesNode, ctx) end function assemble(n::GroupNode, ctx) - has_aggregates = any(ref -> @dissect(ref, Agg() || Agg() |> Nested()), ctx.refs) + has_aggregates = + n.name === nothing && any(ref -> @dissect(ref, Agg()), ctx.refs) || + any(ref -> @dissect(ref, Nested(name = (local name))) && name == n.name, ctx.refs) if isempty(n.by) && !has_aggregates # NOOP: already processed in link() return assemble(nothing, ctx) end @@ -652,9 +629,9 @@ function assemble(n::GroupNode, ctx) if @dissect(ref, nothing |> Get(name = (local name))) @assert name in keys(n.label_map) push!(trns, ref => by[n.label_map[name]]) - elseif @dissect(ref, nothing |> Agg()) + elseif n.name === nothing && @dissect(ref, nothing |> Agg()) push!(trns, ref => translate(ref, ctx, subs)) - elseif @dissect(ref, (local tail = nothing |> Agg()) |> Nested()) + elseif @dissect(ref, (local tail) |> Nested(name = (local name))) && name == n.name push!(trns, ref => translate(tail, ctx, subs)) end end @@ -675,6 +652,33 @@ function assemble(n::GroupNode, ctx) return Assemblage(base.name, s, cols = cols, repl = repl) end +function assemble(n::IntoNode, ctx) + base = assemble(ctx) + if all(@dissect(ref, (local tail) |> Nested()) && tail in keys(base.repl) for ref in ctx.refs) + repl′ = Dict{SQLQuery, Symbol}() + for ref in ctx.refs + @dissect(ref, (local tail) |> Nested()) || error() + repl′[ref] = base.repl[tail] + end + return Assemblage(n.name, base.syntax, cols = base.cols, repl = repl′) + end + if !@dissect(base.syntax, SELECT() || UNION()) + base_alias = nothing + s = base.syntax + else + base_alias = allocate_alias(ctx, base) + s = FROM(AS(name = base_alias, tail = complete(base))) + end + subs = make_subs(base, base_alias) + trns = Pair{SQLQuery, SQLSyntax}[] + for ref in ctx.refs + @dissect(ref, (local tail) |> Nested()) || error() + push!(trns, ref => translate(tail, ctx, subs)) + end + repl, cols = make_repl_cols(trns) + Assemblage(base.name, s, cols = cols, repl = repl) +end + function assemble(n::IterateNode, ctx) ctx′ = TranslateContext(ctx, vars = Base.ImmutableDict{Tuple{Symbol, Int}, SQLSyntax}()) left = assemble(ctx) @@ -846,14 +850,13 @@ function assemble(n::PartitionNode, ctx) partition = PARTITION(by = by, order_by = order_by, frame = n.frame) trns = Pair{SQLQuery, SQLSyntax}[] has_aggregates = false + ctx′′ = TranslateContext(ctx′, partition = partition) for ref in ctx.refs if @dissect(ref, nothing |> Agg()) && n.name === nothing - @dissect(translate(ref, ctx′), AGG(name = (local name), args = (local args), filter = (local filter))) || error() - push!(trns, ref => AGG(; name, args, filter, over = partition)) + push!(trns, ref => translate(ref, ctx′′)) has_aggregates = true - elseif @dissect(ref, (local tail = nothing |> Agg()) |> Nested(name = (local name))) && name === n.name - @dissect(translate(tail, ctx′), AGG(name = (local name), args = (local args), filter = (local filter))) || error() - push!(trns, ref => AGG(; name, args, filter, over = partition)) + elseif @dissect(ref, (local tail) |> Nested(name = (local name))) && name === n.name + push!(trns, ref => translate(tail, ctx′′)) has_aggregates = true else push!(trns, ref => subs[ref]) @@ -883,22 +886,16 @@ function assemble(n::RoutedJoinNode, ctx) right = assemble(n.joinee, ctx) end if @dissect(right.syntax, (local joinee = (ID() || AS())) |> FROM()) && (!n.left || _outer_safe(right)) - for (ref, name) in right.repl - subs[ref] = right.cols[name] - end + right_alias = nothing if ctx.catalog.dialect.has_implicit_lateral lateral = false end else right_alias = allocate_alias(ctx, right) joinee = AS(name = right_alias, tail = complete(right)) - right_cache = Dict{Symbol, SQLSyntax}() - for (ref, name) in right.repl - subs[ref] = get(right_cache, name) do - ID(name = name, tail = right_alias) - end - end end + right_subs = make_subs(right, right_alias) + merge!(subs, right_subs) on = translate(n.on, ctx, subs) s = JOIN(joinee = joinee, on = on, left = n.left, right = n.right, lateral = lateral, tail = tail) trns = Pair{SQLQuery, SQLSyntax}[] diff --git a/src/types.jl b/src/types.jl index 856821ed..24d0c369 100644 --- a/src/types.jl +++ b/src/types.jl @@ -13,17 +13,21 @@ PrettyPrinting.quoteof(::EmptyType) = Expr(:call, nameof(EmptyType)) struct ScalarType <: AbstractSQLType + ScalarType() = + new() end -PrettyPrinting.quoteof(::ScalarType) = +function PrettyPrinting.quoteof(t::ScalarType) Expr(:call, nameof(ScalarType)) +end struct RowType <: AbstractSQLType fields::OrderedDict{Symbol, Union{ScalarType, RowType}} group::Union{EmptyType, RowType} + private_fields::Set{Symbol} - RowType(fields, group = EmptyType()) = - new(fields, group) + RowType(fields, group = EmptyType(), private_fields = Set{Symbol}()) = + new(fields, group, private_fields) end const FieldTypeMap = OrderedDict{Symbol, Union{ScalarType, RowType}} @@ -32,8 +36,8 @@ const GroupType = Union{EmptyType, RowType} RowType() = RowType(FieldTypeMap()) -RowType(fields::Pair{Symbol, <:AbstractSQLType}...; group = EmptyType()) = - RowType(FieldTypeMap(fields), group) +RowType(fields::Pair{Symbol, <:AbstractSQLType}...; group = EmptyType(), private_fields = Set{Symbol}()) = + RowType(FieldTypeMap(fields), group, private_fields) function PrettyPrinting.quoteof(t::RowType) ex = Expr(:call, nameof(RowType)) @@ -43,6 +47,9 @@ function PrettyPrinting.quoteof(t::RowType) if !(t.group isa EmptyType) push!(ex.args, Expr(:kw, :group, quoteof(t.group))) end + if !isempty(t.private_fields) + push!(ex.args, Expr(:kw, :private_fields, t.private_fields)) + end ex end @@ -62,16 +69,20 @@ function Base.intersect(t1::RowType, t2::RowType) return t1 end fields = FieldTypeMap() + private_fields = Set{Symbol}() for f in keys(t1.fields) if f in keys(t2.fields) t = intersect(t1.fields[f], t2.fields[f]) if !isa(t, EmptyType) fields[f] = t + if f in t1.private_fields && f in t2.private_fields + push!(private_fields, f) + end end end end group = intersect(t1.group, t2.group) - RowType(fields, group) + RowType(fields, group, private_fields) end @@ -91,7 +102,7 @@ function Base.issubset(t1::RowType, t2::RowType) return true end for f in keys(t1.fields) - if !(f in keys(t2.fields) && issubset(t1.fields[f], t2.fields[f])) + if !(f in keys(t2.fields) && issubset(t1.fields[f], t2.fields[f]) && (!(f in t1.private_fields) || f in t2.private_fields)) return false end end