Approximating a predicate for drinking water
Aligning goals in crafter
In this example we will optimize simple models that detect if a goal has been achieved in a dynamic environment. The goal has been chosen in advance, and we have a ground truth representation of it, which allows us to get synthetic data on which to optimize the model. The objective of this analysis is to get an idea of the resources needed to align simple predicates in a simple simulated environment.
Environment
We will use the crafter environment, a simplified 2D version of Minecraft. It features an agent that has to survive and accomplish “achievements”, which are defined by the authors of the environment; there are also a few items that can be crafted.
Trajectories
For simplicity and for the purposes of this document, this is what we define as a trajectory:
Code
from dataclasses import dataclass

import numpy as np

@dataclass
class Trajectory:
    observations: list[np.ndarray]  # List of RGB frames
    info_dicts: list[dict]  # Featurized representation of the state (not officially part of crafter)

    @property
    def timestep_n(self) -> int:
        return len(self.info_dicts)

We sample trajectories by rolling out random actions until the episode ends:
Code
import gym
import crafter
import plotting
from pathlib import Path

output_dir = Path("output_dir")
output_dir.mkdir(exist_ok=False)

FPS = 15

def sample_random_trajectory() -> Trajectory:
    # Instantiate environment
    env = gym.make('CrafterReward-v1', apply_api_compatibility=True)

    # Roll out a random trajectory
    obs = env.reset()
    done = False
    observations = list()
    info_dicts = list()
    while not done:
        action = env.action_space.sample()
        obs, reward, done, truncated, info = env.step(action)
        observations.append(obs)
        info_dicts.append(info)
        if done:
            obs = env.reset()
    env.close()
    return Trajectory(observations=observations, info_dicts=info_dicts)
# Sample some trajectories and extract observations
trajectory_n = 2
trajectories = [
    sample_random_trajectory()
    for _ in range(trajectory_n)
]
frames = [
    observation
    for trajectory in trajectories
    for observation in trajectory.observations
]
# Plot trajectories
animation_file = output_dir/"demo_animation.webm"
plotting.plot_animation(
    frames,
    FPS,
    animation_file,
)

# Display trajectories
from IPython.display import display, Markdown
display(Markdown(
    f"Here are {trajectory_n} random trajectories in a single video:"
))
display(Markdown("{{< video " + str(animation_file) + " >}}"))
Here are 2 random trajectories in a single video:
Let us now sample a larger batch of random trajectories so that we can compute some statistics and inspect individual trajectories.
Code
# Time sampling process
import time
start_t = time.time()

# Sample trajectories
sample_n = 1000
sampled_trajectories = [
    sample_random_trajectory()
    for _ in range(sample_n)
]

# Display running time
final_t = time.time()
time_s = final_t - start_t
display(Markdown(
    f"Sequentially sampling {sample_n} trajectories (i.e., acting until the "
    f"environment returns 'done' because the character died) took {time_s} seconds."
))
Sequentially sampling 1000 trajectories (i.e., acting until the environment returns ‘done’ because the character died) took 350.0899295806885 seconds.
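That is roughly 0.35 seconds per trajectory, a per-sample cost worth keeping in mind when budgeting the experiments below.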
Predicates
Our immediate goal is to estimate the computation budget necessary to approximate predicates in crafter; our final goal is to infer predicates from demonstrations. For example, we imagine collecting a dataset of labeled trajectories in which the agent “drinks water”.
We define “predicate” as any function \(P : \text{Trajectory} \to \{0, 1\}\).
For now, we will solve the simplest formulation of this problem, where we have collected a dataset where every trajectory has been labeled as “SAT” or “UNSAT”, and we wish to find a function which approximates this decision boundary. This simple formulation will probably change in the future.
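As an illustration of the type, here is a toy predicate (hypothetical, not used in what follows) that checks whether the agent survived for at least 100 timesteps:

# Toy example of a predicate P : Trajectory -> {0, 1}.
# Hypothetical; only meant to illustrate the interface.
def survived_100_steps(trajectory: Trajectory) -> bool:
    return trajectory.timestep_n >= 100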
Case study: drinking water
While crafter is a challenge in which agents are supposed to act based on the observed image, for simplicity (and only temporarily) we first approximate the predicate using the “info” dictionary, a featurized representation of the environment state.
Since in this analysis we are only interested in the computational cost needed to approximate this predicate, we will define a ground truth predicate in order to remove the need for user data. The predicate under consideration is “the agent drank water in the trajectory”:
def has_drunk_water(trajectory: Trajectory) -> bool:
    """Return `True` if and only if the agent drank water in the given
    `trajectory`."""
    features = trajectory.info_dicts
    for t in range(1, trajectory.timestep_n):
        previous_drink = features[t-1]["inventory"]["drink"]
        current_drink = features[t]["inventory"]["drink"]
        if previous_drink < current_drink:
            return True
    return False
This is as simple as it gets.
Note that the predicate is true if the agent drank water at any moment during the trajectory, not necessarily at the end. We now classify the sampled trajectories using the ground-truth predicate.
Code
sat_trajectories = list()
unsat_trajectories = list()
for trajectory in sampled_trajectories:
    if has_drunk_water(trajectory):
        sat_trajectories.append(trajectory)
    else:
        unsat_trajectories.append(trajectory)

assert len(sat_trajectories) > 0, "We should not perform the analysis with zero SAT trajectories"
assert len(unsat_trajectories) > 0, "We should not perform the analysis with zero UNSAT trajectories"
We sampled 92 SAT trajectories and 908 UNSAT trajectories.
Here is a random trajectory that satisfies the predicate:
Code
# Obtain the frames of a SAT trajectory
frames = sat_trajectories[0].observations

# Plot trajectory
animation_file = output_dir/"sat_trajectory.webm"
plotting.plot_animation(
    frames,
    FPS,
    animation_file,
)

# Display trajectory
display(Markdown("{{< video " + str(animation_file) + " >}}"))
Here is a random trajectory that does not satisfy the predicate:
Code
# Obtain the frames of an UNSAT trajectory
frames = unsat_trajectories[0].observations

# Plot trajectory
animation_file = output_dir/"unsat_trajectory.webm"
plotting.plot_animation(
    frames,
    FPS,
    animation_file,
)

# Display trajectory
display(Markdown("{{< video " + str(animation_file) + " >}}"))
Featurization
Let us featurize the trajectories and train a classifier. This is the hacky, and also important, part of the experiment, as it encodes a lot of domain knowledge.
The domain knowledge that we will encode consists of the following two key observations:
- we only care about the value of the “drink” state, and
- we only care about the window leading to the precise moment the predicate becomes true (and we know the length of the window).
We thus featurize trajectories into fixed-size lists of numbers by extracting the “drink” values in the last timesteps leading up to the precise moment the predicate becomes true, or in a random time window if the predicate is false:
import random

window_size = 5

def get_features(trajectory: Trajectory, window_size: int) -> list[float]:
    """Transform the trajectory into a fixed-size list of numbers."""
    # This list will contain "moving-window" subtrajectories of fixed size
    subtrajectories = list()

    # Slide a time-window across the trajectory
    for t in range(window_size, trajectory.timestep_n):
        subtrajectory = Trajectory(
            info_dicts=trajectory.info_dicts[t-window_size:t],
            observations=trajectory.observations[t-window_size:t],
        )
        subtrajectories.append(subtrajectory)

    # Classify subtrajectories
    sat_subtrajectories = list()
    unsat_subtrajectories = list()
    for subtrajectory in subtrajectories:
        if has_drunk_water(subtrajectory):
            sat_subtrajectories.append(subtrajectory)
        else:
            unsat_subtrajectories.append(subtrajectory)

    # Choose a random subtrajectory that satisfies the predicate; if there is
    # none, choose one that does not satisfy it.
    # This guarantees that we use a window containing the instant where the
    # predicate becomes true, whenever such an instant exists.
    if len(sat_subtrajectories) > 0:
        subtrajectory = random.choice(sat_subtrajectories)
    else:
        subtrajectory = random.choice(unsat_subtrajectories)

    # Transform the subtrajectory into a list of numbers
    features = [
        info_dict["inventory"]["drink"]
        for info_dict in subtrajectory.info_dicts
    ]
    return features
Here is a featurized version of a random SAT trajectory:
Code
print(get_features(sat_trajectories[0], window_size))
[0, 1, 1, 1, 1]
Here is a featurized version of a random UNSAT trajectory:
Code
print(get_features(unsat_trajectories[0], window_size))
[8, 8, 8, 7, 7]
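For intuition, with this featurization a correct classifier only has to check whether the drink value ever increases inside the window. A minimal hand-written sketch (not used in the experiments; the learned models below should essentially rediscover this rule):

def window_is_sat(features: list[float]) -> bool:
    """Hand-written baseline: true iff the drink level increases at some
    point inside the window."""
    return any(a < b for a, b in zip(features, features[1:]))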
Models
We now optimize a simple model to approximate the predicate. That is, we will execute an algorithm that searches for a function that has a low error rate within our labeled dataset. We thus need to decide on:
- a search space consisting of functions \(\text{Trajectory} \to \{0, 1\}\),
- an algorithm to search in said space.
We also need to transform our list of trajectories into a binary classification problem, where we map each featurized trajectory to \(1\) if the trajectory is SAT and to \(0\) if it is UNSAT.
Code
X = list()
y = list()
for trajectory in sat_trajectories:
    X.append(get_features(trajectory, window_size))
    y.append(1)
for trajectory in unsat_trajectories:
    X.append(get_features(trajectory, window_size))
    y.append(0)
A separate set of trajectories will be used to estimate the performance of the models we train.
Code
X_test = list()
y_test = list()
test_sample_n = 1000
for _ in range(test_sample_n):
    trajectory = sample_random_trajectory()
    X_test.append(get_features(trajectory, window_size))
    if has_drunk_water(trajectory):
        y_test.append(1)
    else:
        y_test.append(0)
We sampled 1000 trajectories, of which there are 97 SAT instances and 903 UNSAT instances.
Decision tree
Let us first consider one of the simplest models: decision trees. We will use the sklearn implementation, which is described as “an optimized version of the CART algorithm”.
Code
from sklearn.tree import DecisionTreeClassifier
import time

# Time the algorithm
start_t = time.time()

# Execute algorithm
clf = DecisionTreeClassifier()
clf.fit(X, y)

# Display execution time
end_t = time.time()
time_s = end_t - start_t
display(Markdown(f"The algorithm took {time_s} seconds to run."))
The algorithm took 0.004048585891723633 seconds to run.
This is the tree that was found:
Code
import matplotlib.pyplot as plt
from sklearn import tree

fig, axes = plt.subplots(nrows=1, ncols=1, figsize=(4, 4), dpi=300)
class_names = ["UNSAT", "SAT"]
tree.plot_tree(
    clf,
    class_names=class_names,
    filled=True,
)
[Figure: the fitted decision tree, rendered with tree.plot_tree. The root split is x[0] <= 5.5 over 1000 samples (908 UNSAT, 92 SAT); most subsequent splits compare the first and last drink values in the window (x[0] and x[4]) until pure SAT/UNSAT leaves are reached.]
Code
test_accuracy = clf.score(X_test, y_test)
display(Markdown(
    f"The algorithm had a mean accuracy of {test_accuracy} on the test data."
))
The algorithm had a mean accuracy of 0.993 on the test data.
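For context, the test set is heavily imbalanced (903 of 1000 trajectories are UNSAT), so always predicting UNSAT would already score about 0.903. A quick sanity check (not part of the original analysis):

# Accuracy of the trivial majority-class baseline (always predict UNSAT).
majority_accuracy = y_test.count(0) / len(y_test)  # 903 / 1000 = 0.903 here

The tree's 0.993 is therefore a real improvement over this trivial baseline.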
Multi-layer perceptron
We will also use the sklearn MLP implementation, which is trained with the traditional backpropagation algorithm.
Because gradient-based optimization is brittle in many ways, the standard approach is to normalize the features. In this case we normalize the feature range to \([0, 1]\) using statistics from the train set. There is also class imbalance, which we will not address for now.
Code
max_feature = max([xi for xs in X for xi in xs])
min_feature = min([xi for xs in X for xi in xs])

def get_normalized_for_mlp(features: list[float]) -> list[float]:
    return [
        (xi - min_feature) / (max_feature - min_feature)
        for xi in features
    ]

X_mlp = [get_normalized_for_mlp(x) for x in X]
X_test_mlp = [get_normalized_for_mlp(x) for x in X_test]
Here is a normalized version of the features of a random trajectory:
Code
print(X_mlp[0])
[0.1111111111111111, 0.1111111111111111, 0.2222222222222222, 0.2222222222222222, 0.2222222222222222]
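Note that the normalized values are multiples of \(1/9\), which is consistent with min_feature being \(0\) and max_feature being \(9\) on this training set (the full range of the “drink” stat in crafter).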
We then execute the backpropagation algorithm on the normalized features to find an approximation to the predicate.
Code
from sklearn.neural_network import MLPClassifier

# Time the algorithm
start_t = time.time()

# Execute algorithm
clf = MLPClassifier(
    hidden_layer_sizes=(32, 32),
)
clf.fit(X_mlp, y)

# Display execution time
end_t = time.time()
time_s = end_t - start_t
display(Markdown(f"The algorithm took {time_s} seconds to run."))
ConvergenceWarning: Stochastic Optimizer: Maximum iterations (200) reached and the optimization hasn't converged yet.
The algorithm took 0.2840995788574219 seconds to run.
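The warning above indicates that training stopped at sklearn's default iteration cap rather than at convergence. If we wanted the optimizer to converge, one option (a tweak, not part of the run above) is to raise max_iter:

clf = MLPClassifier(
    hidden_layer_sizes=(32, 32),
    max_iter=1000,  # default is 200; more iterations let the optimizer converge
)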
Code
test_accuracy = clf.score(X_test_mlp, y_test)
display(Markdown(
    f"The algorithm had a mean accuracy of {test_accuracy} on the test data."
))
The algorithm had a mean accuracy of 0.99 on the test data.
Next steps
By far the biggest limitation of this experiment is that we work with a fixed set of labels, whereas we wish to allow users to provide instructions in free-form natural language. This can be addressed in at least two ways:
- Design an algorithm to manipulate the predicate space dynamically, deciding if a new class should be added.
- Map user utterances directly to predicates, either explicitly (an utterance-to-predicate-code model) or implicitly (zero-shot prompting); see the sketch below.
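As a sketch of the explicit option, one could prompt a language model to emit predicate code. Here llm_complete is a hypothetical stand-in for any text-completion API, not a real library call:

# Hypothetical sketch: map a user utterance to predicate source code.
# `llm_complete` stands in for any LLM completion call; it is not a real API.
def utterance_to_predicate_code(utterance: str) -> str:
    prompt = (
        "Write a Python function `predicate(trajectory: Trajectory) -> bool` "
        f"that returns True iff the following holds: {utterance}"
    )
    return llm_complete(prompt)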