Add POMDP example and change HMM example to work with DiscreteTransition and DirichletCollection #12

Open
wants to merge 3 commits into
base: main

wouterwln
Member

@FraserP117 will check the POMDP tutorial, but I think it is in a mergeable state already

FraserP117 self-assigned this Feb 7, 2025

FraserP117

Many thanks @wouterwln

This is really great stuff! I've run the example and played around a bit. I'll lay out my thoughts and queries regarding your explanations below. Note: I have only looked at the POMDP example thus far.

  1. Environment Setup: I found the env setup with RxEnvironments perfectly transparent. The link to the docs on WindyGridWorld was very helpful.

  2. Model Setup:

    2.1. You say: "We will use the DiscreteTransition node in RxInfer to define the state transition model." However, I could not find the DiscreteTransition node definition or its (dev/stable) docs in RxInfer or ReactiveMP. My guess is that this is what the Transition node used to be called? Even though you consistently call DiscreteTransition, everything works fine, so I feel like I'm misunderstanding something.

    2.2. I would really love to see the docs on the Transition/DiscreteTransition node. Regarding your explanation of the Transition node, I'm not sure what an "interface" is. I take it that in and out, from out ~ Transition(in, parameters, additional_interfaces...), are tensors of some kind, with in being, for example, the prior on the transition model B and out being the posterior for the transition model? If so, should an "interface" just be taken to mean any argument to the node?

    2.3. Regarding the model definition, I can basically work out what everything is here, though it would be nice to add comments that explicitly label each variable.

  3. Variational Constraints: I found these straightforward and your explanation sufficed to make it so. My only issue is that I don't understand why we have:

init = @initialization begin
    q(A) = DirichletCollection(diageye(25) .+ 0.1)
    q(B) = DirichletCollection(ones(25, 25, 4))
end

instead of

init = @initialization begin
    q(A) = DirichletCollection(diageye(36) .+ 0.1)
    q(B) = DirichletCollection(ones(36, 36, 4))
end

Given that the WindyGridWorld is a (6, 6) grid.

  4. Priors on Model Params: This makes sense and I think you explained it well. Again, I question why p_A, p_B and the 3 methods defined here use 25 instead of 36. Perhaps the goal is to be found within a smaller radius than the full (6, 6) grid, so we don't need to model the whole grid? Perhaps you originally made the grid (5, 5) and then changed your mind? More likely, I'm just misunderstanding something!

  5. Main Loop: I think this is fairly straightforward to understand; however, I have some thoughts. I like the explanation of the chosen order of operations: we first take an action, then observe, and then update our beliefs. Regarding the actual call to infer(), I absolutely do not understand why we have:

...
m_A = mean(p_A),
m_B = mean(p_B)
...

I get that you say: "The real reason we did this is because we do not want messages from the future to influence the model parameters, instead only learning the model parameters from past data." Fair enough, and thank you for that context; otherwise I would have no idea what was going on here. Still, I don't understand why future messages would influence inference over the model parameters if this step were omitted. I also don't see why that would be such a bad thing in principle. Perhaps the experimental setup requires it in this specific case?
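
On the 25-versus-36 question raised above, here is a minimal sketch of how the array sizes would follow from the grid dimensions. It assumes one discrete state per grid cell and that the wind tuple passed to WindyGridWorld in the notebook has one entry per column of a square grid; both are my assumptions, not something the notebook states. n_states is a hypothetical helper, and the plain arrays only stand in for the DirichletCollection parameters.

```julia
using LinearAlgebra  # stdlib, provides the identity `I`

# Assumption: one discrete state per grid cell.
# `n_states` is a hypothetical helper, not part of RxInfer or RxEnvironments.
n_states(grid_size::NTuple{2,Int}) = prod(grid_size)

# Assumption: the wind tuple has one entry per column and the grid is square.
wind = (0, 1, 1, 1, 0)
grid_size = (length(wind), length(wind))
n = n_states(grid_size)

# Plain-array stand-ins for the concentration parameters of q(A) and q(B).
A_counts = Matrix{Float64}(I, n, n) .+ 0.1  # observation model: n-by-n
B_counts = ones(n, n, 4)                    # transition model: n-by-n per control

@show n size(A_counts) size(B_counts)
```

Under these assumptions n comes out as 25, which would match the code, while a (6, 6) grid would give 36, which would match the text.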

There are one or two trivial spelling mistakes, but that's a real nit-pick. What I think would help me most now is to see the docs on Transition/DiscreteTransition, in addition to digesting the source code, the latter of which I can do at any time.

Many thanks again! I'm keen to help out wherever on this.

],
"source": [
"include(\"env.jl\")\n",
"env = RxEnvironment(WindyGridWorld((0, 1, 1, 1, 0), [], (4, 3)))\n",

It looks like this only defines a (5, 5) grid, yet the rest of the text assumes a (6, 6) grid?

"source": [
"## Model Setup\n",
"\n",
"First, we'll define our POMDP model structure. We will use the `DiscreteTransition` node in `RxInfer` to define the state transition model. The `DiscreteTransition` node is a special node that accepts any number of `Categorical` distributions as input, and outputs a `Categorical` distribution. This means that we can use it to define a state transition model that accepts the previous state and the control as `Categorical` random variables, but we can also use it to define our observation model! Furthermore, the `DiscreteTransition` node can be used both for parameter inference and for inference-as-planning, isn't that neat?"

I can't find any such node in RxInfer, though there is a Transition node defined in ReactiveMP.

"output_type": "display_data"
}
],
"source": [

25 or 36?

"cell_type": "markdown",
"metadata": {},
"source": [
"Now, in order to use this model, we have to define the priors for the model parameters. The WindyGridworld environment has a 6-by-6 grid, so we need to instantiate a prior 36-by-36 transition matrices for every control! That's quite a lot of parameters, but as we will see, `RxInfer` will handle this just fine. We will give our agent a control space of 4 actions, so we need to instantiate 4 transition matrices. Furthermore, we have to transform the output from the environment to a 1-in-36 index, and the controls from a 1-in-4 index to a direction tuple."

It seems like the text assumes a (6, 6) grid, while the code actually uses a (5, 5) grid.
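
For what it's worth, the index transformations described in that paragraph could be sketched as below, with the grid size kept in a single constant so the text and the code cannot drift apart. This is only an illustration: GRID_SIZE, position_to_index, one_hot, DIRECTIONS and index_to_direction are hypothetical names, the column-major ordering and the order of the four directions are my assumptions, and none of this is taken from the notebook.

```julia
# Assumption: a 5-by-5 grid, matching what the code appears to use.
const GRID_SIZE = (5, 5)

# Hypothetical helper: (x, y) grid position -> linear state index (column-major).
position_to_index(pos::NTuple{2,Int}) = LinearIndices(GRID_SIZE)[pos...]

# Hypothetical helper: state index -> one-hot vector of length n.
one_hot(k::Int, n::Int) = [i == k ? 1.0 : 0.0 for i in 1:n]

# Assumed ordering of the 4 controls; the notebook may order them differently.
const DIRECTIONS = ((0, 1), (1, 0), (0, -1), (-1, 0))
index_to_direction(u::Int) = DIRECTIONS[u]

k = position_to_index((4, 3))        # linear index of position (4, 3)
obs = one_hot(k, prod(GRID_SIZE))    # 1-in-25 encoding of that position
```

Changing GRID_SIZE to (6, 6) would turn this into the 1-in-36 encoding the text describes.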

"cell_type": "markdown",
"metadata": {},
"source": [
"`RxEnvironments.jl` is a package that allows us to easily communicate between our agent and our environment. We can senc actions to the environment, and the environment will automatically respond with the corresponding observations. In order to access these in our model, we can subscribe to the observations and then use the `data` function to access the last observation."

"sync" not "senc"?
