States#

Overview#

States are how we communicate with simulated environments. We provide two state types: TensorState and DictEnvState. TensorState uses torch tensors as its underlying data structure and is designed for efficiency, while DictEnvState uses nested dicts for more user-friendly interaction.

Tensor State#

TensorState is a dataclass of vectorized states. At the top level, it contains four fields: objects, robots, cameras, and extras. Each field in turn contains multiple sub-fields of data, defined by the dataclasses below.

@dataclass
class ObjectState:
    """State of a single object."""

    root_state: torch.Tensor
    """Root state ``[pos, quat, lin_vel, ang_vel]``. Shape is (num_envs, 13)."""
    body_names: list[str] | None = None
    """Body names. This is only available for articulation objects."""
    body_state: torch.Tensor | None = None
    """Body state ``[pos, quat, lin_vel, ang_vel]``. Shape is (num_envs, num_bodies, 13). This is only available for articulation objects."""
    joint_pos: torch.Tensor | None = None
    """Joint positions. Shape is (num_envs, num_joints). This is only available for articulation objects."""
    joint_vel: torch.Tensor | None = None
    """Joint velocities. Shape is (num_envs, num_joints). This is only available for articulation objects."""

@dataclass
class RobotState:
    """State of a single robot."""

    root_state: torch.Tensor
    """Root state ``[pos, quat, lin_vel, ang_vel]``. Shape is (num_envs, 13)."""
    body_names: list[str]
    """Body names."""
    body_state: torch.Tensor
    """Body state ``[pos, quat, lin_vel, ang_vel]``. Shape is (num_envs, num_bodies, 13)."""
    joint_pos: torch.Tensor
    """Joint positions. Shape is (num_envs, num_joints)."""
    joint_vel: torch.Tensor
    """Joint velocities. Shape is (num_envs, num_joints)."""
    joint_pos_target: torch.Tensor
    """Joint positions target. Shape is (num_envs, num_joints)."""
    joint_vel_target: torch.Tensor
    """Joint velocities target. Shape is (num_envs, num_joints)."""
    joint_effort_target: torch.Tensor
    """Joint effort targets. Shape is (num_envs, num_joints)."""

@dataclass
class CameraState:
    """State of a single camera."""

    ## Images
    rgb: torch.Tensor | None
    """RGB image. Shape is (num_envs, H, W, 3)."""
    depth: torch.Tensor | None
    """Depth image. Shape is (num_envs, H, W)."""
    instance_id_seg: torch.Tensor | None = None
    """Instance id segmentation for each pixel. Shape is (num_envs, H, W)."""
    instance_id_seg_id2label: dict[int, str] | None = None
    """Instance id segmentation id to label mapping. Keys are instance ids, values are labels. Go together with :attr:`instance_id_seg`."""
    instance_seg: torch.Tensor | None = None
    """Instance segmentation for each pixel. Shape is (num_envs, H, W).

    .. warning::
        This is experimental and subject to change.
    """
    instance_seg_id2label: dict[int, str] | None = None
    """Instance segmentation id to label mapping. Keys are instance ids, values are labels. Go together with :attr:`instance_seg`.

    .. warning::
        This is experimental and subject to change.
    """

    ## Camera parameters
    pos: torch.Tensor | None = None
    """Position of the camera. Shape is (num_envs, 3)."""
    quat_world: torch.Tensor | None = None
    """Quaternion ``(w, x, y, z)`` of the camera, following the world frame convention. Shape is (num_envs, 4).

    Note:
        In the world frame convention, the camera's forward axis is +X and its up axis is +Z.
    """
    intrinsics: torch.Tensor | None = None
    """Intrinsics matrix of the camera. Shape is (num_envs, 3, 3)."""

@dataclass
class TensorState:
    """Tensorized state of the simulation."""

    objects: dict[str, ObjectState]
    """States of all objects."""
    robots: dict[str, RobotState]
    """States of all robots."""
    cameras: dict[str, CameraState]
    """States of all cameras."""
    extras: dict = field(default_factory=dict)
    """States of Extra information"""

To obtain a TensorState from a handler, use the handler.get_state(mode="tensor") method. The return value is a TensorState describing the current simulation state.

Then, you can access the TensorState with:

tensor_state = handler.get_state(mode="tensor")
object_pos = tensor_state.objects["ball"].root_state[:, 0:3]  # root_state.shape = (num_envs, 13)
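
Following the [pos, quat, lin_vel, ang_vel] layout documented above (3 + 4 + 3 + 3 = 13 entries), the remaining slices of root_state can be read the same way; the robot name "franka" below is hypothetical:

object_quat = tensor_state.objects["ball"].root_state[:, 3:7]       # orientation quaternion
object_lin_vel = tensor_state.objects["ball"].root_state[:, 7:10]   # linear velocity
object_ang_vel = tensor_state.objects["ball"].root_state[:, 10:13]  # angular velocity
joint_pos = tensor_state.robots["franka"].joint_pos                 # shape (num_envs, num_joints)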

TensorState instances can also be passed into the handler to set the simulation state.
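
For example, a minimal write-back sketch, assuming the handler exposes a set_state() counterpart to get_state() (the setter name is an assumption; check your handler's API):

tensor_state = handler.get_state(mode="tensor")
tensor_state.objects["ball"].root_state[:, 2] += 0.1  # raise the ball by 10 cm in every env
handler.set_state(tensor_state)  # assumed setter name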

By default, Metasim TaskWrappers encourage users to write the observation() function so that it, too, returns a TensorState.
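
As a hypothetical sketch (the base-class name and the self.handler attribute are assumptions for illustration):

class MyTaskWrapper(TaskWrapper):  # base-class name assumed
    def observation(self):
        # Returning the TensorState keeps observations vectorized across envs.
        return self.handler.get_state(mode="tensor")  # self.handler assumed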

One disadvantage of TensorState is that it is difficult for human users to understand the mapping between tensor indices and actual states. We therefore also provide a more user-friendly interface.

Dict State#

DictEnvState is more user-friendly than TensorState, but sacrifices efficiency.

Dof = Dict[str, float]

class DictObjectState(TypedDict):
    """State of the object."""

    pos: torch.Tensor
    rot: torch.Tensor
    vel: torch.Tensor | None
    ang_vel: torch.Tensor | None
    dof_pos: Dof | None
    dof_vel: Dof | None


class DictRobotState(DictObjectState):
    """State of the robot."""

    dof_pos: Dof | None
    dof_vel: Dof | None

    dof_pos_target: Dof | None
    dof_vel_target: Dof | None
    dof_torque: Dof | None

class DictEnvState(TypedDict):
    """State of the environment."""

    objects: dict[str, DictObjectState]
    robots: dict[str, DictRobotState]
    cameras: dict[str, dict[str, torch.Tensor]]
    extras: dict[str, Any]      # Extra information

To obtain a DictEnvState from a handler, use the handler.get_state(mode="dict") method. The return value is a DictEnvState describing the current simulation state.

Then, you can access the DictEnvState with:

dict_state = handler.get_state(mode="dict")
object_pos = dict_state["objects"]["ball"]["pos"]
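
The nesting mirrors the TypedDicts above: per-joint values are plain name-to-float mappings (the Dof alias), and camera data is keyed by name. The robot, joint, camera, and image keys below are hypothetical:

elbow_pos = dict_state["robots"]["franka"]["dof_pos"]["elbow_joint"]  # keys hypothetical
rgb = dict_state["cameras"]["camera0"]["rgb"]                         # keys hypothetical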

DictEnvState instances can also be passed into the handler to set the simulation state.

The disadvantage of DictEnvState is its speed: frequent dict accesses and non-contiguous memory can make it extremely slow in some cases.
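
A common pattern, sketched under the same assumptions as above, is to keep tensor mode on the hot path and fetch a dict state only when a human needs to inspect it:

for _ in range(num_steps):                    # num_steps defined elsewhere
    state = handler.get_state(mode="tensor")  # fast, vectorized access every step
    ...                                       # compute actions from state tensors
snapshot = handler.get_state(mode="dict")     # slow but human-readable; use sparingly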