diff --git a/Manifest.toml b/Manifest.toml index 5317cf9..a6421e7 100644 --- a/Manifest.toml +++ b/Manifest.toml @@ -2,7 +2,7 @@ julia_version = "1.9.3" manifest_format = "2.0" -project_hash = "3fee985c9eca9c1165906a81b33e7ff0c8e89188" +project_hash = "95d68df1d1fba8de2378675f0b447fab215f5eb8" [[deps.Accessors]] deps = ["CompositionsBase", "ConstructionBase", "Dates", "InverseFunctions", "LinearAlgebra", "MacroTools", "Test"] @@ -102,6 +102,11 @@ git-tree-sha1 = "fc08e5930ee9a4e03f84bfb5211cb54e7769758a" uuid = "5ae59095-9a9b-59fe-a467-6f913c188581" version = "0.12.10" +[[deps.Combinatorics]] +git-tree-sha1 = "08c8b6831dc00bfea825826be0bc8336fc369860" +uuid = "861a8166-3701-5b0c-9a16-15d98fcdc6aa" +version = "1.0.2" + [[deps.CommonSubexpressions]] deps = ["MacroTools", "Test"] git-tree-sha1 = "7b8a93dba8af7e3b42fecabf646260105ac373f7" @@ -357,10 +362,16 @@ version = "0.72.9+1" [[deps.Gen]] deps = ["Compat", "DataStructures", "Distributions", "ForwardDiff", "FunctionalCollections", "JSON", "LinearAlgebra", "MacroTools", "Parameters", "Random", "ReverseDiff", "SpecialFunctions"] -git-tree-sha1 = "9878ff4ab1990f5647e89b4228a3c9da5f0e69c7" +path = "/home/dg963/GalileoEvents/env.d/jenv/dev/Gen" uuid = "ea4f424c-a589-11e8-07c0-fd5c91b9da4a" version = "0.4.6" +[[deps.GenParticleFilters]] +deps = ["Distributions", "Gen", "Parameters", "Statistics"] +git-tree-sha1 = "1009fe501115947ba57a9e9e90a3acd3e8d476bb" +uuid = "56b76ac4-72ef-411e-b419-6d312ed86a6f" +version = "0.1.8" + [[deps.Gen_Compose]] deps = ["DataStructures", "FileIO", "Gen", "JLD2", "Revise", "UnPack"] git-tree-sha1 = "b2ac8a12976eb7e459d39c29e7dbbe5ec64199c8" @@ -733,8 +744,8 @@ uuid = "69de0a69-1ddd-5017-9359-2bf0b02dc9f0" version = "2.7.2" [[deps.PhyBullet]] -deps = ["Accessors", "Conda", "Distributions", "DocStringExtensions", "Gen", "Parameters", "PhySMC", "Plots", "PyCall", "Revise", "StaticArrays", "UnicodePlots"] -git-tree-sha1 = "f553bbf8cfdc3a291380ee17f600f818f6cad054" +deps = ["Accessors", "Conda", "Distributions", "DocStringExtensions", "Gen", "Parameters", "PhySMC", "PyCall", "Revise", "StaticArrays", "UnicodePlots"] +git-tree-sha1 = "9fddd996bd0e0fded73a3a15d7f579e1f76cc46f" repo-rev = "master" repo-url = "https://github.com/CNCLgithub/PhyBullet" uuid = "63daae69-5b14-439d-ac6f-096429ca839b" diff --git a/Project.toml b/Project.toml index 2a186be..af12d93 100644 --- a/Project.toml +++ b/Project.toml @@ -5,13 +5,17 @@ version = "0.1.0" [deps] Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" +Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" Gen = "ea4f424c-a589-11e8-07c0-fd5c91b9da4a" +GenParticleFilters = "56b76ac4-72ef-411e-b419-6d312ed86a6f" Gen_Compose = "c1ef4dca-b0a6-4a35-b24b-46cbf3979a16" +LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Parameters = "d96e819e-fc66-5662-9728-84c9c7592b0a" PhyBullet = "63daae69-5b14-439d-ac6f-096429ca839b" PhySMC = "79c1e2f5-7911-41a0-b248-4858717ddd79" +Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" PyCall = "438e738f-606a-5dbb-bf0a-cddfbfd45ab0" Revise = "295af30f-e4ad-537b-8983-00126c2a3abe" diff --git a/src/gms/cp_gm_pb.jl b/src/gms/cp_gm_pb.jl new file mode 100644 index 0000000..61ce186 --- /dev/null +++ b/src/gms/cp_gm_pb.jl @@ -0,0 +1,294 @@ +using Revise + +export CPParams, + CPState, + cp_model, + EventRelation, + Collision + +using LinearAlgebra:norm +using Combinatorics + +## Changepoint Model + components + +""" +Event types +""" +abstract type EventRelation end + +struct Collision <: EventRelation + a::RigidBody + b::RigidBody +end + +struct NoEvent <: EventRelation end + + +""" +holds parameters for change point model +""" +struct CPParams <: GMParams + # prior + material_prior::MaterialPrior + physics_prior::PhysPrior + # event relations + event_concepts::Vector{Type{<:EventRelation}} + # simulation + sim::BulletSim + template::BulletState + n_objects::Int64 + obs_noise::Float64 + death_factor::Float64 +end + +""" +constructs parameter struct for change point model +""" +function CPParams(client::Int64, objs::Vector{Int64}, + mprior::MaterialPrior, pprior::PhysPrior, + event_concepts::Vector{Type{<:EventRelation}}, + obs_noise::Float64=0., + death_factor=10.) + # configure simulator with the provided + # client id + sim = BulletSim(;client=client) + # These are the objects of interest in the scene + rigid_bodies = RigidBody.(objs) + # Retrieve the default latents for the objects + # as well as their initial positions + # Note: alternative latents will be suggested by the `prior` + template = BulletState(sim, rigid_bodies) + + CPParams(mprior, pprior, event_concepts, sim, template, length(objs), obs_noise, death_factor) +end + +""" +Current state of the change point model, simulation state and event state +""" +struct CPState <: GMState + bullet_state::BulletState + active_events::Set{Int64} +end + +## PRIOR + +""" +initalizes prior beliefs about mass, friction and resitution of the given objects +""" +@gen function cp_object_prior(ls::RigidBodyLatents, gm::CPParams) + # sample material + mi = @trace(categorical(gm.material_prior.material_weights), :material) + + # sample physical properties + phys_params = gm.physics_prior + mass_mu, mass_sd = phys_params.mass + mass = @trace(trunc_norm(mass_mu, mass_sd, 0., Inf), :mass) + fric_mu, fric_sd = phys_params.friction + friction = @trace(trunc_norm(fric_mu,fric_sd, 0., 1.), :friction) + res_low, res_high = phys_params.restitution + restitution = @trace(uniform(res_low, res_high), :restitution) + + # package + new_ls = setproperties(ls.data; + mass = mass, + lateralFriction = friction, + restitution = restitution) + new_latents::RigidBodyLatents = RigidBodyLatents(new_ls) + return new_latents +end + +""" +initializes belief about all objects and events +""" +@gen function cp_prior(params::CPParams) + # initialize the kinematic state + latents = params.template.latents + params_filled = Fill(params, length(latents)) + new_latents = @trace(Gen.Map(cp_object_prior)(latents, params_filled), :objects) + bullet_state = setproperties(params.template; latents = new_latents) + + # initialize the event state + active_events = Set{Int64}() + + init_state = CPState(bullet_state, active_events) + return init_state +end + +""" +Bernoulli weight that event relation holds +""" +function predicate(t::Type{Collision}, a::RigidBodyState, b::RigidBodyState) + if norm(Vector(a.linear_vel)-Vector(b.linear_vel)) < 0.01 + return 0 + end + + a_dim = a.aabb[2] - a.aabb[1] + b_dim = b.aabb[2] - b.aabb[1] + d = norm(Vector(a.position)-Vector(b.position))-norm((a_dim+b_dim)/2) # l2 distance + clamp(exp(-15d), 1e-3, 1 - 1e-3) +end + + +""" +update latents of a single element +""" +@gen function update_latents(latents::BulletElemLatents) + new_mass = @trace(trunc_norm(latents.data.mass, .1, 0., Inf), :mass) + new_restitution = @trace(trunc_norm(latents.data.restitution, .1, 0., 1.), :restitution) + + new_latents = setproperties(latents.data; + mass = new_mass, + restitution = new_restitution) + new_latents = RigidBodyLatents(new_latents) + return new_latents +end + +""" +in case of collision: Gaussian drift update of mass and restitution +""" +@gen function _collision_clause(pair_idx::Vector{Int64}, latents::Vector{BulletElemLatents}) + latents[pair_idx[1]] = @trace(update_latents(latents[pair_idx[1]]), :new_latents_a) + latents[pair_idx[2]] = @trace(update_latents(latents[pair_idx[2]]), :new_latents_b) + return latents +end + +function clause(::Type{Collision}) + _collision_clause +end + +""" +in case of no event: no change +""" +@gen function _no_event_clause(pair_idx, latents::Vector{BulletElemLatents}) + return latents +end + +function clause(::Type{NoEvent}) + _no_event_clause +end + +event_concepts = Type{<:EventRelation}[NoEvent, Collision] +switch = Gen.Switch(map(clause, event_concepts)...) + +""" +TODO: this function was intended to check if some event relations are impossible to be created a certain time step +""" +function valid_relations(state::CPState, event_concepts::Vector{Type{EventRelation}}) + return event_concepts + # TODO: replace by map in the end + for EventRelation in event_concepts + # TODO: decide if valid + end +end + +@gen function event_switch(clause, events, start_event_idx, pair, latents) + switch = Gen.Switch(map(clause, events)...) + return switch(start_event_idx, pair, latents) +end + +""" +map possible events to weight vector for birth decision using the predicates +""" +function calculate_predicates(obj_states) + object_pairs = collect(combinations(obj_states, 2)) + pair_idx = repeat(collect(combinations(1:length(obj_states), 2)), length(event_concepts)) + pair_idx = [[0,0], pair_idx...] # [0,0] for no event + + # break up to two lines + predicates = [predicate(event_type, a, b) for event_type in event_concepts for (a, b) in object_pairs if event_type != NoEvent] # NoEvent excluded and added in weights + event_ids = vcat(1, repeat(2:length(event_concepts), inner=length(object_pairs))) # 1 for NoEvent + return predicates, event_ids, pair_idx +end + +""" +transform predicates for pairs of objects into a probability vector that adds to 1, including one weight for NoEvent at the first position +""" +function normalize_weights(weights, active_events) + for idx in active_events # active events should not be born again + weights[idx-1] = 0 + end + weights = [max(0, 1 - sum(weights)), weights...] # first element for NoEvent + # TODO: objects that are already involved in some events should not be involved in other event types as well + return weights ./ sum(weights) +end + +""" +similar to normalize_weights but for death of events +""" +function calculate_death_weights(predicates, active_events, start_event_idx, death_factor) + can_die(idx) = idx+1 in active_events && idx+1 != start_event_idx + # dying has a much lower chance of being born + get_weight(idx) = can_die(idx) ? max(1. - predicates[idx] * death_factor, 0.) : 0.0 + weights = [get_weight(idx) for idx in 1:length(predicates)] + weights = [max(0, 1-sum(weights)), weights...] # no event at index 1 + return weights ./ sum(weights) +end + +""" +updates active events in a functional form +add=True -> add event to set of active events +add=False -> remove event from set of active events +""" +function update_active_events(active_events::Set{Int64}, event_idx::Int64, add::Bool) + if event_idx == 1 + return active_events + end + if add + return union(active_events, Set([event_idx])) + else + return setdiff(active_events, Set([event_idx])) + end +end + + +""" +iterate over event concepts and evaluate predicates for newly activated events +""" +@gen function event_kernel(active_events, bullet_state, death_factor) + predicates, event_ids, pair_idx = calculate_predicates(bullet_state.kinematics) + weights = normalize_weights(copy(predicates), active_events) + start_event_idx = @trace(categorical(weights), :start_event_idx) # up to one event is born + + updated_latents = @trace(switch(event_ids[start_event_idx], pair_idx[start_event_idx], bullet_state.latents), :event) + bullet_state = setproperties(bullet_state; latents = updated_latents) + active_events = update_active_events(active_events, start_event_idx, true) + + weights = calculate_death_weights(predicates, active_events, start_event_idx, death_factor) + end_event_idx = @trace(categorical(weights), :end_event_idx) # up to one active event dies + active_events = update_active_events(active_events, end_event_idx, false) + + return active_events, bullet_state +end + +""" +for one object, observe the noisy position in every dimension +""" +@gen function observe_position(k::RigidBodyState, noise::Float64) + @trace(broadcasted_normal(k.position, noise), :positions) +end + +""" +run event and physics kernel for one time step and observe noisy positions +""" +@gen function kernel(t::Int, prev_state::CPState, params::CPParams) + active_events, bullet_state = @trace(event_kernel(prev_state.active_events, + prev_state.bullet_state, + params.death_factor), :events) + + bullet_state::BulletState = PhySMC.step(params.sim, bullet_state) + @trace(Gen.Map(observe_position)(bullet_state.kinematics, Fill(params.obs_noise, params.n_objects)), :observe) + + return CPState(bullet_state, active_events) +end + +""" +generate physical scene with changepoints in the belief state +""" +@gen function cp_model(t::Int, params::CPParams) + # initalize the kinematic and event state + init_state = @trace(cp_prior(params), :prior) + + # unfold the event and kinematic state over time + states = @trace(Gen.Unfold(kernel)(t, init_state, params), :kernel) + return states +end diff --git a/src/gms/gms.jl b/src/gms/gms.jl index db12e91..164d556 100644 --- a/src/gms/gms.jl +++ b/src/gms/gms.jl @@ -80,4 +80,4 @@ struct PhysPrior end include("mc_gm.jl") -# include("cp_gm.jl") +include("cp_gm_pb.jl") diff --git a/src/gms/mc_gm.jl b/src/gms/mc_gm.jl index dae9ea2..ab38f66 100644 --- a/src/gms/mc_gm.jl +++ b/src/gms/mc_gm.jl @@ -105,7 +105,6 @@ end @gen function mc_gm(t::Int, gm::MCParams) init_state = @trace(mc_prior(gm), :prior) # simulate `t` timesteps - println(init_state) states = @trace(Gen.Unfold(kernel)(t, init_state, gm), :kernel) return states end diff --git a/test/gms/cp_gm.jl b/test/gms/cp_gm.jl new file mode 100644 index 0000000..be89e13 --- /dev/null +++ b/test/gms/cp_gm.jl @@ -0,0 +1,236 @@ +using Revise +using Gen +using GalileoEvents +using Plots +ENV["GKSwstype"]="160" # fixes some plotting warnings + +mass_ratio = 2.0 +obj_frictions = (0.3, 0.3) +obj_positions = (0.5, 1.2) + +mprior = MaterialPrior([unknown_material]) +pprior = PhysPrior((3.0, 10.0), # mass + (0.5, 10.0), # friction + (0.2, 1.0)) # restitution + +obs_noise = 0.05 +t = 80 + +fixed_prior_cm = Gen.choicemap() +fixed_prior_cm[:prior => :objects => 1 => :mass] = 2. +fixed_prior_cm[:prior => :objects => 2 => :mass] = 1. +fixed_prior_cm[:prior => :objects => 1 => :friction] = 0.5 +fixed_prior_cm[:prior => :objects => 2 => :friction] = 0.5 +fixed_prior_cm[:prior => :objects => 1 => :restitution] = 0.2 +fixed_prior_cm[:prior => :objects => 2 => :restitution] = 0.2 + +function forward_test() + client, a, b = ramp(mass_ratio, obj_frictions, obj_positions) + event_concepts = Type{<:EventRelation}[Collision] + cp_params = CPParams(client, [a,b], mprior, pprior, event_concepts, obs_noise) + + trace, weight = Gen.generate(cp_model, (t, cp_params)); + @show weight + #display(get_choices(trace)) +end + +function add_rectangle!(plt, xstart, xend, y; height=0.8, color=:blue) + xvals = [xstart, xend, xend, xstart, xstart] + yvals = [y, y, y+height, y+height, y] + plot!(plt, xvals, yvals, fill=true, seriestype=:shape, fillcolor=color, linecolor=color) +end + +get_x2(trace, t) = get_retval(trace)[t].bullet_state.kinematics[2].position[1] + +function visualize_active_events() + client, a, b = ramp(mass_ratio, obj_frictions, obj_positions) + event_concepts = Type{<:EventRelation}[Collision] + cp_params = CPParams(client, [a,b], mprior, pprior, event_concepts, obs_noise) + + num_traces = 50 + plt = plot(legend=false, xlim=(0, t), ylim=(1, num_traces+1), yrotation=90, ylabel="Trace", yticks=false, xlabel="Time step") + collision_t = nothing + for i in 1:num_traces + if i % 10 == 0 + @show i + end + trace, _ = Gen.generate(cp_model, (t, cp_params), fixed_prior_cm); + + start = nothing + first_x = i==1 ? get_x2(trace, 1) : nothing # only look for collision in first trace + for j in 1:t + if trace[:kernel=>j=>:events=>:start_event_idx]==2 + start = j + end + if trace[:kernel=>j=>:events=>:end_event_idx]==2 + finish = j + add_rectangle!(plt, start, finish, i) + end + if first_x !== nothing && abs(first_x - get_x2(trace, j)) > 0.001 + collision_t = j + first_x = nothing + end + end + + end + vline!(plt, [collision_t], linecolor=:red, linewidth=2, label="Vertical Line") + savefig(plt, "test/gms/plots/events.png") +end + +# constrained generation, event 2 must start at timestep 10 +function constrained_test() + client, a, b = ramp(mass_ratio, obj_frictions, obj_positions) + event_concepts = Type{<:EventRelation}[Collision] + cp_params = CPParams(client, [a,b], mprior, pprior, event_concepts, obs_noise) + + #addr = 10 => :events => :start_event_idx + #cm = Gen.choicemap(addr => 2) + trace, weight = Gen.generate(cp_model, (t, cp_params), fixed_prior_cm) + @show weight + #display(get_choices(trace)) +end + +# update priors +function update_test() + t = 120 + + client, a, b = ramp(mass_ratio, obj_frictions, obj_positions) + event_concepts = Type{<:EventRelation}[Collision] + cp_params = CPParams(client, [a,b], mprior, pprior, event_concepts, obs_noise) + trace, _ = Gen.generate(cp_model, (t, cp_params)) + + addr = :prior => :objects => 1 => :mass + cm = Gen.choicemap(addr => trace[addr] + 3) + trace2, _ = Gen.update(trace, cm) + + # compare final positions + t=120 + pos1 = Vector(get_retval(trace)[t].bullet_state.kinematics[1].position) + pos2 = Vector(get_retval(trace2)[t].bullet_state.kinematics[1].position) + @assert pos1 != pos2 + + return trace, trace2 +end + +# change event start +function update_test_2() + + client, a, b = ramp(mass_ratio, obj_frictions, obj_positions) + event_concepts = Type{<:EventRelation}[Collision] + cp_params = CPParams(client, [a,b], mprior, pprior, event_concepts, obs_noise) + + # generate initial trace + trace, ls = Gen.generate(cp_model, (t, cp_params), fixed_prior_cm) + + # find first collision in the trace + start_event_indices = [trace[:kernel=>i=>:events=>:start_event_idx] for i in 1:t] + t1 = findfirst(x -> x == 2, start_event_indices) + @show ls + choices = get_choices(trace) + display(get_submap(choices, :kernel => t1 => :events)) + + # TODO: validate existence of event + # move first collision five steps earlier + cm = choicemap() + cm[:kernel => t1 => :events => :start_event_idx] = 1 + cm[:kernel => t1 - 5 => :events => :start_event_idx] = 2 + trace2, ls2, _... = Gen.update(trace, cm) + #@show ls2 + choices = get_choices(trace2) + #display(get_submap(choices, :kernel => t1 => :events)) + #display(get_submap(choices, :kernel => t1 -5 => :events)) + + # the keys have to be enumerated, subsets do not work + trace3, delta_s, _... = Gen.regenerate(trace, select( + :kernel => t1 => :events => :event => :new_latents_a => :mass,)) + #:kernel => t1 => :events => :event => :new_latents_a => :restitution)) + + @show delta_s + choices2 = get_choices(trace3) + display(get_submap(choices2, :kernel => t1 => :events)) + + + for i in 1:t + if project(trace3, select(:kernel => i)) == -Inf + @show i + display(get_submap(choices2, :kernel => i => :events)) + end + end + @show t1 + @show project(trace3, select(:kernel)) + @assert delta_s != -Inf + @assert !isnan(delta_s) + + #return trace, trace2 +end + +# redraw latents at same event start +function update_test_3() + + client, a, b = ramp(mass_ratio, obj_frictions, obj_positions) + event_concepts = Type{<:EventRelation}[Collision] + cp_params = CPParams(client, [a,b], mprior, pprior, event_concepts, obs_noise) + + # generate initial trace + trace, _ = Gen.generate(cp_model, (t, cp_params)) + + # find first collision in the trace + start_event_indices = [trace[:kernel=>i=>:events=>:start_event_idx] for i in 1:t] + t1 = findfirst(x -> x == 2, start_event_indices) + + # in future maybe gaussian rw + trace2, delta_s, _... = Gen.regenerate(trace, select(:kernel => t1 => :events => :event)) + + @assert delta_s != -Inf + @assert delta_s != NaN + + return trace, trace2 +end + + +# test switch combinator in terms of gen's reaction to proposed changes + +# toy model for dealing with complexing +# random walk with 2 delta functions (gaussian vs uniform) chosen by switch +# initial trace is changed by a mh proposal for switch index +# static first, unfold complexity second step + +@gen function function1() + v ~ normal(0., 1.) +end + +@gen function function2() + v ~ uniform(-1., 1.) +end + +switch = Gen.Switch(function1, function2) + +@gen function switch_model_static() + function_idx = @trace(categorical([0.5, 0.5]), :function) + x = @trace(switch(function_idx), :x) + y = @trace(normal(x, 1.), :y) +end + +function switch_test_static() + # unconstrained generation + trace, _ = Gen.generate(switch_model_static, ()) + display(get_choices(trace)) + + # constrained generation + cm = Gen.choicemap(:function => 1) + trace2, _ = Gen.generate(switch_model_static, (), cm) + display(get_choices(trace2)) + + # update and regenerate trace + trace3, _ = Gen.update(trace, cm) + trace4, _ = Gen.regenerate(trace3, select(:x)) + display(get_choices(trace4)) +end + +#forward_test() +visualize_active_events() +#constrained_test() +#update_test() +#update_test_2() +#update_test_3() +#switch_test_static() \ No newline at end of file diff --git a/test/gms/plots/events.png b/test/gms/plots/events.png new file mode 100644 index 0000000..82c93fa Binary files /dev/null and b/test/gms/plots/events.png differ diff --git a/test/gms/plots/x_positions.png b/test/gms/plots/x_positions.png new file mode 100644 index 0000000..74b83e3 Binary files /dev/null and b/test/gms/plots/x_positions.png differ diff --git a/test/particle_filter.jl b/test/particle_filter.jl new file mode 100644 index 0000000..8c94a46 --- /dev/null +++ b/test/particle_filter.jl @@ -0,0 +1,148 @@ +using Revise +using GalileoEvents +using Gen +using Printf +using Plots +using GenParticleFilters +using Distributions +ENV["GKSwstype"]="160" # fixes some plotting warnings + +""" +gen_trial + +Generates a trial and returns the generation parameters, the true trace and the observations +""" +function gen_trial() + # configure model paramaters + mass_ratio = 1. # rand(Gamma(5.0, 0.25)) # is overwritten anyway + obj_frictions = (rand(Uniform(0.1, 0.2)), rand(Uniform(0.1, 0.2))) + obj_positions = (rand(Uniform(0.5, 0.6)), rand(Uniform(1.1, 1.2))) + @show obj_positions + mprior = MaterialPrior([unknown_material]) + pprior = PhysPrior((3.0, 10.0), # mass + (0.5, 0.25), # friction + (0.2, 1.0)) # restitution + client, a, b = ramp(mass_ratio, obj_frictions, obj_positions) + event_concepts = Type{<:EventRelation}[Collision] + obs_noise = 0.05 + + cp_params = CPParams(client, [a,b], mprior, pprior, event_concepts, obs_noise) + + t = 70 + + # run model forward + trace, _ = Gen.generate(cp_model, (t, cp_params)); + + # collect observations + choices = get_choices(trace) + observations = Vector{Gen.ChoiceMap}(undef, t) + for i = 1:t + prefix = :kernel => i => :observe + cm = choicemap() + set_submap!(cm, prefix, get_submap(choices, prefix)) + observations[i] = cm + end + + return t, cp_params, trace, observations +end + +@gen function proposal(trace) + # find first collision in the trace + t = get_args(trace)[1] + start_event_indices = [trace[:kernel=>i=>:events=>:start_event_idx] for i in 1:t] + t1 = findfirst(x -> x == 2, start_event_indices) + + if !isnothing(t1) + # in future, maybe gaussian rw + trace2, delta_s, _... = Gen.regenerate(trace, select( + :kernel => t1 => :events => :event => :new_latents_a => :mass, + :kernel => t1 => :events => :event => :new_latents_b => :mass + )) + return trace2, delta_s + end + + return trace, 0 +end + +""" +do_inference + +Runs particle filter inference on a model and given observations +""" +function do_inference(t::Int, params::CPParams, observations::Vector{ChoiceMap}, n_particles::Int = 100, ess_thresh=0.5) + # initialize particle filter + state = pf_initialize(cp_model, (0, params), EmptyChoiceMap(), n_particles) + + # Then increment through each observation step + for t in 1:length(observations) + # Update filter state with new observation at timestep t + pf_update!(state, (t, params), (UnknownChange(), NoChange()), observations[t]) + + step_time = @elapsed begin + # Resample and rejuvenate if the effective sample size is too low + if effective_sample_size(state) < ess_thresh * n_particles + # Perform residual resampling, pruning low-weight particles + pf_resample!(state, :residual) + end + # Perform a rejuvenation move on past choices + #rejuv_sel = select(:kernel => t => :events => :event) + #pf_rejuvenate!(state, mh, (rejuv_sel,)) + + kern(trace) = move_reweight(trace, proposal, ()) + pf_move_reweight!(state, kern) + end + + if t % 10 == 0 + @printf "%s time steps completed (last step was %0.2f seconds)\n" t step_time + end + end + + # return the "unweighted" set of traces after t steps + return get_traces(state), get_log_weights(state) +end + + +function plot_trace(tr::Gen.Trace, title="Trajectory") + (t, _) = get_args(tr) + # get the prior choice for the two masses + choices = get_choices(tr) + masses = [round(choices[:prior => :objects => i => :mass], digits=2) for i in 1:2] + + # get the x positions + states = get_retval(tr) + #diplsay(states) + xs = [map(st -> st.bullet_state.kinematics[i].position[1], states) for i = 1:2] + + # return plot + plot(1:t, xs, title=title, labels=["ramp: $(masses[1])" "table: $(masses[2])"], xlabel="t", ylabel="x") +end + +""" +plot_traces(truth::Gen.DynamicDSLTrace, traces::Vector{Gen.DynamicDSLTrace}) + +Display the observed and final simulated trajectory as well as distributions for latents and the score +""" +function plot_traces(truth::Gen.DynamicDSLTrace, traces::Vector{Gen.DynamicDSLTrace}, weights) + observed_plt = plot_trace(truth, "True trajectory") + simulated_plt = plot_trace(last(traces), "Last trace") + + (t, _) = get_args(truth) + num_traces = length(traces) + mass_logs = [[t[:prior => :objects => i => :mass] for t in traces] for i in 1:2] + + scores_plt = plot(1:num_traces, weights, title="Scores", xlabel="trace number", ylabel="log score") + mass_plts = [Plots.histogram(1:num_traces, mass_logs[i], title="Mass $(i == 1 ? "Ramp object" : "Table object")", legend=false) for i in 1:2] + ratio_plt = Plots.histogram(1:num_traces, mass_logs[1]./mass_logs[2], title="mass ramp object / mass table object", legend=false) + plt = plot(observed_plt, simulated_plt, mass_plts..., scores_plt, ratio_plt, size=(1200, 800)) + savefig(plt, "test/plots/particle_filter.png") +end + +# data generation +t, params, truth, observations = gen_trial() +#display(get_choices(truth)) + +# inference +traces, weights = do_inference(t, params, observations, 25) + +# visualize results +plot_traces(truth, traces, weights) \ No newline at end of file diff --git a/test/plots/particle_filter.png b/test/plots/particle_filter.png new file mode 100644 index 0000000..03718af Binary files /dev/null and b/test/plots/particle_filter.png differ