Approximating a predicate for drinking water

Aligning goals in Crafter

In this example we optimize simple models that detect whether a goal has been achieved in a dynamic environment. The goal has been chosen in advance, and we have a ground-truth representation of it, which lets us generate synthetic data on which to optimize the models. The objective of this analysis is to get an idea of the resources needed to align simple predicates in a simple simulated environment.

Environment

We will use the Crafter environment, a simplified 2D version of Minecraft in which an agent must survive and accomplish “achievements” defined by the authors of the environment. A few items can be crafted along the way.

Trajectories

For simplicity and for the purposes of this document, this is what we define as a trajectory:

from dataclasses import dataclass
import numpy as np

@dataclass
class Trajectory:
  observations: list[np.ndarray]  # List of RGB frames
  info_dicts: list[dict]  # Featurized representation of the state
                          # (not officially part of crafter)

  @property
  def timestep_n(self) -> int:
    return len(self.info_dicts)
Code
import gym
import crafter
import plotting
from pathlib import Path

output_dir = Path("output_dir")
output_dir.mkdir(parents=True, exist_ok=True)

FPS = 15

def sample_random_trajectory() -> Trajectory:
  # Instantiate environment (the compatibility flag selects the
  # five-tuple step API)
  env = gym.make('CrafterReward-v1', apply_api_compatibility=True)

  # Roll out a single random trajectory until the episode ends
  obs, info = env.reset()
  done = False
  observations = list()
  info_dicts = list()
  while not done:
    action = env.action_space.sample()
    obs, reward, terminated, truncated, info = env.step(action)
    done = terminated or truncated
    observations.append(obs)
    info_dicts.append(info)
  env.close()
  return Trajectory(observations=observations, info_dicts=info_dicts)

# Sample some trajectories and extract observations
trajectory_n = 2
trajectories = [
  sample_random_trajectory()
  for _ in range(trajectory_n)
]

frames = [
  observation
  for trajectory in trajectories
  for observation in trajectory.observations
]

# Plot trajectories
animation_file = output_dir/"demo_animation.webm"
plotting.plot_animation(
  frames,
  FPS,
  animation_file,
)

# Display trajectories
from IPython.display import display, Markdown
display(Markdown(
  f"Here are {trajectory_n} random trajectories in a single video:"
))
display(Markdown("{{< video " + str(animation_file) + " >}}"))

Here are 2 random trajectories in a single video:

Let us now sample a larger batch of random trajectories so that we can compute some statistics and inspect individual trajectories.

Code
# Time sampling process
import time
start_t = time.time()

# Sample trajectories
sample_n = 1000
sampled_trajectories = [
  sample_random_trajectory()
  for _ in range(sample_n)
]

# Display running time
final_t = time.time()
time_s = final_t - start_t
display(Markdown(
  f"Sequentially sampling {sample_n} trajectories (i.e., acting until the environment returns 'done' because the character died) took {time_s} seconds."
))

Sequentially sampling 1000 trajectories (i.e., acting until the environment returns ‘done’ because the character died) took 350.0899295806885 seconds.
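
Sampling is embarrassingly parallel, so this cost can be spread across worker processes. Here is a minimal sketch (the `sample_many` helper is hypothetical and not used in the rest of this document); note that, depending on the platform, worker processes may require the mapped function to live in an importable module:

from concurrent.futures import ProcessPoolExecutor

def _sample_one(_: int) -> Trajectory:
  return sample_random_trajectory()

def sample_many(n: int) -> list[Trajectory]:
  # Each worker process rolls out trajectories independently
  with ProcessPoolExecutor() as executor:
    return list(executor.map(_sample_one, range(n)))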

Predicates

Our goal is to get an estimate of the computation budget necessary to approximate predicates in Crafter. Our final goal is to infer predicates from demonstrations. For example, we imagine collecting a dataset of labeled trajectories where the agent “drinks water”.

We define “predicate” as any function \(P : \text{Trajectory} \to \{0, 1\}\).
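
In code, this corresponds to any boolean function of a trajectory. The following type alias (purely illustrative; it is not used elsewhere in this document) makes the definition concrete:

from typing import Callable

# Purely illustrative: a predicate maps a trajectory to a boolean
Predicate = Callable[[Trajectory], bool]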

For now, we will solve the simplest formulation of this problem, where we have collected a dataset where every trajectory has been labeled as “SAT” or “UNSAT”, and we wish to find a function which approximates this decision boundary. This simple formulation will probably change in the future.

Case study: drinking water

While Crafter is a benchmark in which agents are supposed to act based on the observed image, for simplicity we will temporarily focus on approximating the predicate using the “info” dictionary, a featurized representation of the environment state.
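
For illustration, here is roughly what one of these dictionaries looks like (abridged and hand-written; the exact set of keys depends on the crafter version):

# Abridged, hand-written example of an "info" dictionary; the predicate
# below only reads inventory["drink"]
example_info = {
  "inventory": {
    "health": 9,
    "food": 9,
    "drink": 9,
    # ... tools and materials such as wood, stone, ...
  },
  "achievements": {
    "collect_drink": 1,
    # ... the other achievements ...
  },
}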

Since in this analysis we are only interested in the computational cost needed to approximate this predicate, we will define a ground truth predicate in order to remove the need for user data. The predicate under consideration is “the agent drank water in the trajectory”:

def has_drunk_water(trajectory: Trajectory) -> bool:
  """Return `True` if and only if the agent drank water in the given
  `trajectory`."""
  features = trajectory.info_dicts
  for t in range(1, trajectory.timestep_n):
    previous_drink = features[t-1]["inventory"]["drink"]
    current_drink  = features[ t ]["inventory"]["drink"]
    if previous_drink < current_drink:
      return True
  return False

This is as simple as it gets.
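
As a quick sanity check, we can evaluate the predicate on a hand-built two-step trajectory (hypothetical values; only the “drink” entry matters):

# Hand-built trajectory in which the drink counter increases once
fake_trajectory = Trajectory(
  observations=[np.zeros((64, 64, 3), dtype=np.uint8)] * 2,
  info_dicts=[
    {"inventory": {"drink": 3}},
    {"inventory": {"drink": 4}},
  ],
)
assert has_drunk_water(fake_trajectory)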

Note that the predicate is true if the agent drank water at any moment during the trajectory, not necessarily at the end. We now classify the sampled trajectories using the ground-truth predicate.

Code
sat_trajectories = list()
unsat_trajectories = list()
for trajectory in sampled_trajectories:
  if has_drunk_water(trajectory):
    sat_trajectories.append(trajectory)
  else:
    unsat_trajectories.append(trajectory)

assert len(sat_trajectories) > 0, "We should not perform the analysis with zero SAT trajectories"
assert len(unsat_trajectories) > 0, "We should not perform the analysis with zero UNSAT trajectories"

display(Markdown(
  f"We sampled {len(sat_trajectories)} SAT trajectories "
  f"and {len(unsat_trajectories)} UNSAT trajectories."
))

We sampled 92 SAT trajectories and 908 UNSAT trajectories.

Here is a random trajectory that satisfies the predicate:

Code
# Obtain the frames of a SAT trajectory
frames = sat_trajectories[0].observations

# Plot trajectory
animation_file = output_dir/"sat_trajectory.webm"
plotting.plot_animation(
  frames,
  FPS,
  animation_file,
)

# Display trajectory
display(Markdown("{{< video " + str(animation_file) + " >}}"))

Here is a random trajectory that does not satisfy the predicate:

Code
# Obtain the frames of an UNSAT trajectory
frames = unsat_trajectories[0].observations

# Plot trajectory
animation_file = output_dir/"unsat_trajectory.webm"
plotting.plot_animation(
  frames,
  FPS,
  animation_file,
)

# Display trajectory
display(Markdown("{{< video " + str(animation_file) + " >}}"))

Featurization

Let us featurize the trajectories and train a classifier. This is the hacky part of the experiment: it encodes a lot of domain knowledge, which makes it the part most worth scrutinizing.

The domain knowledge we encode consists of the following two key observations:

  • we only care about the value of the “drink” state, and
  • we only care about the window leading to the precise moment the predicate becomes true (and we know the length of the window).

We thus featurize trajectories into fixed-size lists of numbers by extracting the “drink” inventory value over the last timesteps leading up to the moment the predicate becomes true, or over a random time window if the predicate is false:

import random

window_size = 5

def get_features(trajectory: Trajectory, window_size: int) -> list[float]:
  """Transform the trajectory into a fixed-size list of numbers."""
  # This list will contain "moving-window" subtrajectories of fixed size
  subtrajectories = list()

  # Slide a time-window across the trajectory
  for t in range(window_size, trajectory.timestep_n):
    subtrajectory = Trajectory(
      info_dicts=trajectory.info_dicts[t-window_size:t],
      observations=trajectory.observations[t-window_size:t],
    )
    subtrajectories.append(subtrajectory)

  # Classify subtrajectories
  sat_subtrajectories = list()
  unsat_subtrajectories = list()
  for subtrajectory in subtrajectories:
    if has_drunk_water(subtrajectory):
      sat_subtrajectories.append(subtrajectory)
    else:
      unsat_subtrajectories.append(subtrajectory)

  # Choose a random subtrajectory that satisfies the predicate; if there is
  # none, choose a random one that does not.
  # This guarantees that, for SAT trajectories, the selected window contains
  # a moment where the predicate becomes true
  if len(sat_subtrajectories) > 0:
    subtrajectory = random.choice(sat_subtrajectories)
  else:
    subtrajectory = random.choice(unsat_subtrajectories)

  # Transform the subtrajectory into a list of numbers
  features = [
    info_dict["inventory"]["drink"]
    for info_dict in subtrajectory.info_dicts
  ]
  return features

Here is a featurized version of a random SAT trajectory:

Code
print(get_features(sat_trajectories[0], window_size))
[0, 1, 1, 1, 1]

Here is a featurized version of a random UNSAT trajectory:

Code
print(get_features(unsat_trajectories[0], window_size))
[8, 8, 8, 7, 7]

Models

We now optimize a simple model to approximate the predicate. That is, we will execute an algorithm that searches for a function that has a low error rate within our labeled dataset. We thus need to decide on:

  • a search space consisting of functions \(\text{Trajectory} \to \{0, 1\}\),
  • an algorithm to search in said space.

We also need to transform our list of trajectories into a binary classification problem, where we map each featurized trajectory to \(1\) if the trajectory is SAT and \(0\) if it is UNSAT.

Code
X = list()
y = list()
for trajectory in sat_trajectories:
  X.append(get_features(trajectory, window_size))
  y.append(1)
for trajectory in unsat_trajectories:
  X.append(get_features(trajectory, window_size))
  y.append(0)

A separate set of trajectories will be used to estimate the performance of the models we train.

Code
X_test = list()
y_test = list()
test_sample_n = 1000
for _ in range(test_sample_n):
  trajectory = sample_random_trajectory()
  X_test.append(get_features(trajectory, window_size))
  if has_drunk_water(trajectory):
    y_test.append(1)
  else:
    y_test.append(0)

sat_test_n = sum(y_test)
display(Markdown(
  f"We sampled {test_sample_n} trajectories, of which there are "
  f"{sat_test_n} SAT instances and {test_sample_n - sat_test_n} UNSAT instances."
))

We sampled 1000 trajectories, of which there are 97 SAT instances and 903 UNSAT instances.

Decision tree

Let us first consider one of the simplest models: decision trees. We will use the sklearn library, which uses “an optimized version of the CART algorithm”.

Code
from sklearn.tree import DecisionTreeClassifier
import time

# Time the algorithm
start_t = time.time()

# Execute algorithm
clf = DecisionTreeClassifier()
clf.fit(X, y)

# Display execution time
end_t = time.time()
time_s = end_t - start_t
display(Markdown(f"The algorithm took {time_s} to run."))

The algorithm took 0.004048585891723633 seconds to run.

This is the tree that was found:

Code
import matplotlib.pyplot as plt
from sklearn import tree

fig, ax = plt.subplots(figsize=(4, 4), dpi=300)

class_names = ["UNSAT", "SAT"]
tree.plot_tree(
  clf,
  class_names=class_names,
  filled=True,
  ax=ax,
)
plt.show()
[Figure: the fitted decision tree, trained on 1000 samples (908 UNSAT, 92 SAT). The root split is x[0] <= 5.5, and deeper nodes mostly compare the first (x[0]) and last (x[4]) “drink” values of the window.]

Code
test_accuracy = clf.score(X_test, y_test)
display(Markdown(
  f"The algorithm had a mean accuracy of {test_accuracy} in the test data."
))

The algorithm had a mean accuracy of 0.993 on the test data.
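
For reference, since 903 of the 1000 test trajectories are UNSAT, a trivial classifier that always predicts UNSAT would already score about 0.90, so the tree clearly beats the majority-class baseline.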

Multi-layer perceptron

We will also use sklearn's MLP implementation, which trains the network with the traditional backpropagation algorithm.

Because gradient-based optimization is brittle in many ways, the standard approach is to normalize the features. In this case we normalize the feature range to \([0, 1]\) using statistics from the train set. There is also class imbalance, which we will not address for now.
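
For reference, one simple way to compensate for the imbalance is naive random oversampling. The `oversample` helper below is a sketch only and is not used in what follows:

import random

def oversample(X: list[list[float]], y: list[int]):
  # Duplicate minority-class examples until the classes are balanced
  minority = [(x, yi) for x, yi in zip(X, y) if yi == 1]
  majority = [(x, yi) for x, yi in zip(X, y) if yi == 0]
  pairs = majority + random.choices(minority, k=len(majority))
  random.shuffle(pairs)
  return [x for x, _ in pairs], [yi for _, yi in pairs]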

Code
max_feature = max([xi for xs in X for xi in xs])
min_feature = min([xi for xs in X for xi in xs])

def get_normalized_for_mlp(features: list[float]) -> list[float]:
  return [
    (xi-min_feature)/(max_feature-min_feature)
    for xi in features
  ]
X_mlp = [get_normalized_for_mlp(x) for x in X]
X_test_mlp = [get_normalized_for_mlp(x) for x in X_test]

Here is the normalized version of the features of the first trajectory in the training set:

Code
print(X_mlp[0])
[0.1111111111111111, 0.1111111111111111, 0.2222222222222222, 0.2222222222222222, 0.2222222222222222]

We then execute the backpropagation algorithm on the normalized features to find an approximation to the predicate.

Code
from sklearn.neural_network import MLPClassifier

# Time the algorithm
start_t = time.time()

# Execute algorithm
clf = MLPClassifier(
  hidden_layer_sizes=(32, 32),
)
clf.fit(X_mlp, y)

# Display execution time
end_t = time.time()
time_s = end_t - start_t
display(Markdown(f"The algorithm took {time_s} to run."))
ConvergenceWarning: Stochastic Optimizer: Maximum iterations (200) reached and the optimization hasn't converged yet.

The algorithm took 0.2840995788574219 seconds to run.

Code
test_accuracy = clf.score(X_test_mlp, y_test)
display(Markdown(
  f"The algorithm had a mean accuracy of {test_accuracy} in the test data."
))

The algorithm had a mean accuracy of 0.99 on the test data.

Next steps

By far the biggest limitation of this experiment is that we work with a fixed set of labels, whereas we ultimately wish to allow users to provide instructions in free-form natural language. This can be addressed in at least two ways:

  • Design an algorithm that manipulates the predicate space dynamically, deciding when a new class should be added.
  • Map user utterances directly to predicates, either explicitly (with an utterance-to-predicate-code model) or implicitly (via zero-shot prompting).