from __future__ import annotations
from dataclasses import MISSING
from typing import Literal
import torch
from loguru import logger as log
from metasim.cfg.objects import BaseObjMetaCfg
from metasim.utils.configclass import configclass
from metasim.utils.math import euler_xyz_from_quat, matrix_from_quat, quat_from_matrix
from metasim.utils.tensor_util import tensor_to_str
from .base_checker import BaseChecker
from .detectors import BaseDetector
try:
from metasim.sim.base import BaseSimHandler
except:
pass
[docs]
@configclass
class DetectedChecker(BaseChecker):
    """
    Check if the object with `obj_name` (optionally its `obj_subpath`) is detected by `detector`.
    - `ignore_if_first_check_success`: when True, an env whose first check after a reset already
      reports success is marked as ignored and keeps returning False until it is reset again.
    """

    obj_name: str = MISSING
    obj_subpath: str | None = None
    detector: BaseDetector = MISSING
    ignore_if_first_check_success: bool = False

    def reset(self, handler: BaseSimHandler, env_ids: list[int] | None = None):
        """
        Reset the per-env bookkeeping and the underlying detector.

        Args:
            handler: Simulation handler; `handler.num_envs` sizes the per-env flags.
            env_ids: Envs to reset. `None` resets every env.
        """
        if env_ids is None or not hasattr(self, "_first_check"):
            # Full reset (or first call ever): (re)allocate the per-env flags.
            self._first_check = torch.ones(handler.num_envs, dtype=torch.bool)  # True
            self._ignore = torch.zeros(handler.num_envs, dtype=torch.bool)  # False
        else:
            # Partial reset: only touch the requested envs so that the first-check /
            # ignore state of the envs that keep running is preserved.
            self._first_check[env_ids] = True
            self._ignore[env_ids] = False
        self.detector.reset(handler, env_ids=env_ids)

    def check(self, handler: BaseSimHandler) -> torch.BoolTensor:
        """
        Query the detector and mask out ignored envs.

        Returns:
            The detector result with the entries of ignored envs forced to False.
        """
        success = self.detector.is_detected(handler, self.obj_name, self.obj_subpath)
        if self.ignore_if_first_check_success:
            # Envs that succeed on their very first check are ignored from now on.
            self._ignore[self._first_check & success] = True
        # After this call no env is on its first check anymore.
        self._first_check[:] = False

        success[self._ignore] = False
        return success

    def get_debug_viewers(self) -> list[BaseObjMetaCfg]:
        """Return the debug viewers provided by the detector."""
        return self.detector.get_debug_viewers()
[docs]
@configclass
class JointPosChecker(BaseChecker):
    """
    Compare the position of the joint `joint_name` of the object `obj_name` with `radian_threshold`.
    - `mode` is "ge" to test `position >= radian_threshold` and "le" to test `position <= radian_threshold`.
    """

    obj_name: str = MISSING
    joint_name: str = MISSING
    mode: Literal["ge", "le"] = MISSING
    radian_threshold: float = MISSING

    def check(self, handler: BaseSimHandler) -> torch.BoolTensor:
        """
        Read the joint position from `handler` and compare it with the threshold.

        Raises:
            ValueError: If `mode` is neither "ge" nor "le".
        """
        joint_pos = handler.get_dof_pos(self.obj_name, self.joint_name)
        if self.mode == "ge":
            return joint_pos >= self.radian_threshold
        if self.mode == "le":
            return joint_pos <= self.radian_threshold
        raise ValueError(f"Invalid mode: {self.mode}")

    def get_debug_viewers(self) -> list[BaseObjMetaCfg]:
        """This checker has no debug viewers; always returns an empty list."""
        return []
[docs]
@configclass
class JointPosShiftChecker(BaseChecker):
    """
    Check if the joint with `joint_name` of the object with `obj_name` was moved more than `threshold` units
    relative to the position recorded at the last reset.
    - `threshold` is negative for moving towards the negative direction and positive for moving towards the positive direction.
    """

    obj_name: str = MISSING
    joint_name: str = MISSING
    threshold: float = MISSING

    def reset(self, handler: BaseSimHandler, env_ids: list[int] | None = None):
        """
        Record the current joint position as the reference for the given envs.

        Args:
            handler: Simulation handler used to read the joint position.
            env_ids: Envs to record. `None` records every env.
        """
        target_envs = list(range(handler.num_envs)) if env_ids is None else env_ids
        # Allocate the reference buffer lazily on the first reset.
        if not hasattr(self, "init_joint_pos"):
            self.init_joint_pos = torch.zeros(handler.num_envs, dtype=torch.float32)
        self.init_joint_pos[target_envs] = handler.get_dof_pos(
            self.obj_name, self.joint_name, env_ids=target_envs
        )

    def check(self, handler: BaseSimHandler) -> torch.BoolTensor:
        """
        Compare the joint displacement since the last reset with `threshold`.

        Returns:
            `shift >= threshold` when `threshold` is positive, otherwise `shift <= threshold`.
        """
        shift = handler.get_dof_pos(self.obj_name, self.joint_name) - self.init_joint_pos
        log.debug(f"Joint {self.joint_name} of object {self.obj_name} moved {tensor_to_str(shift)} units")
        return shift >= self.threshold if self.threshold > 0 else shift <= self.threshold
[docs]
@configclass
class UpAxisRotationChecker(BaseChecker):
    """
    Check if an angle derived from the rotation of the object with `obj_name` and the given `axis`
    (for example, "z", [0,0,1]) is within `degree_threshold` degrees of `target_degree`.
    - `degree_threshold` should be in the range of [0, 180].
    - `target_degree` is the expected angle in degrees.
    - `axis` should be one of "x", "y", "z". default is "z".
    """

    ## ref: https://github.com/mees/calvin_env/blob/c7377a6485be43f037f4a0b02e525c8c6e8d24b0/calvin_env/envs/tasks.py#L54

    obj_name: str = MISSING
    degree_threshold: float = MISSING
    target_degree: float = MISSING
    axis: Literal["x", "y", "z"] = "z"

    def reset(self, handler: BaseSimHandler, env_ids: list[int] | None = None):
        """
        Record the current rotation (quaternion) of the object for the given envs.

        Args:
            handler: Simulation handler used to read the object rotation.
            env_ids: Envs to record. `None` records every env.
        """
        # NOTE(review): `init_quat` is stored here but `check` never reads it.
        if env_ids is None:
            env_ids = list(range(handler.num_envs))
        if not hasattr(self, "init_quat"):
            self.init_quat = torch.zeros(handler.num_envs, 4, dtype=torch.float32)
        self.init_quat[env_ids] = handler.get_rot(self.obj_name, env_ids=env_ids)

    def check(self, handler: BaseSimHandler) -> torch.BoolTensor:
        """
        Compute the angle in degrees and compare its distance to `target_degree`.

        Returns:
            True where `abs(angle - target_degree) <= degree_threshold`.
        """
        cur_quat = handler.get_rot(self.obj_name)
        cur_rot_mat = matrix_from_quat(cur_quat)
        v = {"x": 0, "y": 1, "z": 2}[self.axis]

        # Row `v` of the rotation matrix.
        # If cur_rot_mat is batched (e.g. shape [B, 3, 3]), this indexing works over the batch.
        up_axis = cur_rot_mat[..., v, :]

        # Magnitude of the selected row along the last dimension.
        norm = torch.norm(up_axis, dim=-1)

        # NOTE(review): the component index is hard-coded to 1 regardless of `axis` —
        # presumably the object's local y axis is its "up" direction; confirm against the assets.
        # Clamp to [-1, 1]: floating-point rounding can push the ratio slightly outside the
        # domain of acos, which would yield NaN and silently fail the check.
        cos_angle = torch.clamp(up_axis[..., 1] / norm, -1.0, 1.0)

        # Angle in radians, then converted to degrees.
        angle = torch.acos(cos_angle)
        angle = angle * 180.0 / torch.pi

        # Absolute difference from the target degree.
        delta_angle = torch.abs(angle - self.target_degree)

        log.debug(
            f"Object {self.obj_name} rotated {angle} degrees away from {self.axis}-axis, the delta is {delta_angle}"
        )

        return delta_angle <= self.degree_threshold
[docs]
@configclass
class RotationShiftChecker(BaseChecker):
    """
    Check if the object with `obj_name` was rotated more than `radian_threshold` radians around the given `axis`
    since the last reset.
    - `radian_threshold` is negative for clockwise rotations and positive for counter-clockwise rotations.
    - `radian_threshold` should be in the range of [-pi, pi].
    - `axis` should be one of "x", "y", "z". default is "z".
    """

    ## ref: https://github.com/mees/calvin_env/blob/c7377a6485be43f037f4a0b02e525c8c6e8d24b0/calvin_env/envs/tasks.py#L54

    obj_name: str = MISSING
    radian_threshold: float = MISSING
    axis: Literal["x", "y", "z"] = "z"

    def reset(self, handler: BaseSimHandler, env_ids: list[int] | None = None):
        """
        Record the current rotation (quaternion) of the object as the reference for the given envs.

        Args:
            handler: Simulation handler used to read the object rotation.
            env_ids: Envs to record. `None` records every env.
        """
        target_envs = list(range(handler.num_envs)) if env_ids is None else env_ids
        # Allocate the reference buffer lazily on the first reset.
        if not hasattr(self, "init_quat"):
            self.init_quat = torch.zeros(handler.num_envs, 4, dtype=torch.float32)
        self.init_quat[target_envs] = handler.get_rot(self.obj_name, env_ids=target_envs)

    def check(self, handler: BaseSimHandler) -> torch.BoolTensor:
        """
        Compare the rotation about `axis` since the last reset with `radian_threshold`.

        Returns:
            `angle >= radian_threshold` when the threshold is positive, otherwise `angle <= radian_threshold`.
        """
        quat_now = handler.get_rot(self.obj_name)
        rot_init = matrix_from_quat(self.init_quat)
        rot_now = matrix_from_quat(quat_now)
        # Relative rotation: current rotation composed with the transpose of the initial one.
        relative_rot = torch.matmul(rot_now, rot_init.transpose(-1, -2))
        euler_x, euler_y, euler_z = euler_xyz_from_quat(quat_from_matrix(relative_rot))
        angle = {"x": euler_x, "y": euler_y, "z": euler_z}[self.axis]

        ## Normalize the rotation angle to be within [-pi, pi]
        angle = torch.where(angle > torch.pi, angle - 2 * torch.pi, angle)
        angle = torch.where(angle < -torch.pi, angle + 2 * torch.pi, angle)
        assert ((angle >= -torch.pi) & (angle <= torch.pi)).all()

        log.debug(f"Object {self.obj_name} rotated {tensor_to_str(angle / torch.pi * 180)} degrees around {self.axis}-axis")

        return angle >= self.radian_threshold if self.radian_threshold > 0 else angle <= self.radian_threshold
[docs]
@configclass
class PositionShiftChecker(BaseChecker):
    """
    Check if the object with `obj_name` was moved more than `distance` meters in given `axis` since the last reset.
    - `distance` is negative for moving towards the negative direction and positive for moving towards the positive direction.
    - `bounding_distance` is the maximum total distance the object may have moved (only applied when `distance` is positive).
    - `axis` should be one of "x", "y", "z".
    """

    obj_name: str = MISSING
    distance: float = MISSING
    bounding_distance: float = 1e2
    axis: Literal["x", "y", "z"] = MISSING

    def reset(self, handler: BaseSimHandler, env_ids: list[int] | None = None):
        """
        Record the current position of the object as the reference for the given envs.

        Args:
            handler: Simulation handler used to read the object position.
            env_ids: Envs to record. `None` records every env.
        """
        target_envs = list(range(handler.num_envs)) if env_ids is None else env_ids
        # Allocate the reference buffer lazily on the first reset.
        if not hasattr(self, "init_pos"):
            self.init_pos = torch.zeros(handler.num_envs, 3, dtype=torch.float32)

        start_pos = handler.get_pos(self.obj_name, env_ids=target_envs)
        assert start_pos.shape == (len(target_envs), 3)
        self.init_pos[target_envs] = start_pos

    def check(self, handler: BaseSimHandler) -> torch.BoolTensor:
        """
        Compare the displacement along `axis` since the last reset with `distance`.

        Returns:
            All True when any position entry is NaN; otherwise the per-env comparison result.
        """
        pos_now = handler.get_pos(self.obj_name)
        # A NaN anywhere in the positions makes every env report True.
        if torch.isnan(pos_now).any():
            log.debug(f"Object {self.obj_name} moved to nan position")
            return torch.ones(pos_now.shape[0], dtype=torch.bool)

        axis_index = {"x": 0, "y": 1, "z": 2}[self.axis]
        offset = pos_now - self.init_pos
        axis_offset = offset[:, axis_index]
        travelled = torch.norm(offset, dim=-1)
        log.debug(f"Object {self.obj_name} moved {tensor_to_str(axis_offset)} meters in {self.axis} direction")

        if self.distance > 0:
            return (axis_offset >= self.distance) & (travelled <= self.bounding_distance)
        # NOTE(review): `bounding_distance` is not applied in this branch — confirm this is intended.
        return axis_offset <= self.distance
[docs]
@configclass
class PositionShiftCheckerWithTolerance(BaseChecker):
    """
    Check if the object with `obj_name` was moved to `distance` meters in given `axis` with a tolerance of `tolerance`.
    - `distance` is negative for moving towards the negative direction and positive for moving towards the positive direction.
    - `bounding_distance` is the maximum total distance the object may have moved (only applied when `distance` is positive).
    - `tolerance` widens the accepted displacement around `distance`.
    - `axis` should be one of "x", "y", "z".
    """

    ## FIXME: this function is redundant with PositionShiftChecker, we should remove it

    obj_name: str = MISSING
    distance: float = MISSING
    bounding_distance: float = 1e2
    tolerance: float = 0.01
    axis: Literal["x", "y", "z"] = MISSING

    def reset(self, handler: BaseSimHandler, env_ids: list[int] | None = None):
        """
        Record the current position of the object as the reference for the given envs.

        Args:
            handler: Simulation handler used to read the object position.
            env_ids: Envs to record. `None` records every env.
        """
        target_envs = list(range(handler.num_envs)) if env_ids is None else env_ids
        # Allocate the reference buffer lazily on the first reset.
        if not hasattr(self, "init_pos"):
            self.init_pos = torch.zeros(handler.num_envs, 3, dtype=torch.float32)

        start_pos = handler.get_pos(self.obj_name, env_ids=target_envs)
        assert start_pos.shape == (len(target_envs), 3)
        self.init_pos[target_envs] = start_pos

    def check(self, handler: BaseSimHandler) -> torch.BoolTensor:
        """
        Compare the displacement along `axis` since the last reset with `distance` +/- `tolerance`.

        Returns:
            All True when any position entry is NaN; otherwise the per-env comparison result.
        """
        pos_now = handler.get_pos(self.obj_name)
        # Velocity is queried but not used yet (see the TODO below).
        obj_vel = handler.get_vel(self.obj_name)
        # A NaN anywhere in the positions makes every env report True.
        if torch.isnan(pos_now).any():
            log.debug(f"Object {self.obj_name} moved to nan position")
            return torch.ones(pos_now.shape[0], dtype=torch.bool)

        axis_index = {"x": 0, "y": 1, "z": 2}[self.axis]
        offset = pos_now - self.init_pos
        axis_offset = offset[:, axis_index]
        travelled = torch.norm(offset, dim=-1)
        log.debug(f"Object {self.obj_name} moved {tensor_to_str(axis_offset)} meters in {self.axis} direction")

        # TODO: velocity check
        if self.distance > 0:
            above_lower = axis_offset >= self.distance - self.tolerance
            below_upper = axis_offset <= self.distance + self.tolerance
            within_bound = travelled <= self.bounding_distance
            return above_lower & below_upper & within_bound
        # NOTE(review): only an upper bound is tested in this branch — confirm this is intended.
        return axis_offset <= self.distance + self.tolerance
[docs]
@configclass
class SlideChecker(BaseChecker):
    """Flag envs whose robot `torso_upright` value is below 0.6."""

    def check(self, handler: BaseSimHandler) -> torch.BoolTensor:
        """
        Evaluate `torso_upright` for the robot in every state returned by `handler.get_states()`.

        Returns:
            A tensor with one entry per state, True where the value is below 0.6.
        """
        from metasim.utils.humanoid_robot_util import torso_upright

        flags = [
            bool(torso_upright(env_state, handler._robot.name) < 0.6)
            for env_state in handler.get_states()
        ]
        return torch.tensor(flags)
[docs]
@configclass
class WalkChecker(BaseChecker):
    """Flag envs whose robot `z_height` value is below 0.2."""

    def check(self, handler: BaseSimHandler) -> torch.BoolTensor:
        """
        Evaluate `z_height` for the robot in every state returned by `handler.get_states()`.

        Returns:
            A tensor with one entry per state, True where the value is below 0.2.
        """
        from metasim.utils.humanoid_robot_util import z_height

        flags = [
            bool(z_height(env_state, handler._robot.name) < 0.2)
            for env_state in handler.get_states()
        ]
        return torch.tensor(flags)
[docs]
@configclass
class StandChecker(WalkChecker):
    """Checker that inherits the `WalkChecker` behavior without changes."""
[docs]
@configclass
class RunChecker(WalkChecker):
    """Checker that inherits the `WalkChecker` behavior without changes."""
[docs]
@configclass
class CrawlChecker(BaseChecker):
    """Checker that never flags any env."""

    def check(self, handler: BaseSimHandler) -> torch.BoolTensor:
        """Return a tensor of False values, one per state returned by `handler.get_states()`."""
        num_states = len(handler.get_states())
        return torch.tensor([False] * num_states)
[docs]
@configclass
class HurdleChecker(BaseChecker):
    """Checker that never flags any env."""

    def check(self, handler: BaseSimHandler) -> torch.BoolTensor:
        """Return a tensor of False values, one per state returned by `handler.get_states()`."""
        num_states = len(handler.get_states())
        return torch.tensor([False] * num_states)
[docs]
@configclass
class MazeChecker(BaseChecker):
    """Flag envs whose robot `z_height` value is below 0.2."""

    def check(self, handler: BaseSimHandler) -> torch.BoolTensor:
        """
        Evaluate `z_height` for the robot in every state returned by `handler.get_states()`.

        Returns:
            A tensor with one entry per state, True where the value is below 0.2.
        """
        from metasim.utils.humanoid_robot_util import z_height

        flags = [
            bool(z_height(env_state, handler._robot.name) < 0.2)
            for env_state in handler.get_states()
        ]
        return torch.tensor(flags)
[docs]
@configclass
class PoleChecker(BaseChecker):
    """Flag envs whose robot `z_height` value is below 0.6."""

    def check(self, handler: BaseSimHandler) -> torch.BoolTensor:
        """
        Evaluate `z_height` for the robot in every state returned by `handler.get_states()`.

        Returns:
            A tensor with one entry per state, True where the value is below 0.6.
        """
        from metasim.utils.humanoid_robot_util import z_height

        flags = [
            bool(z_height(env_state, handler._robot.name) < 0.6)
            for env_state in handler.get_states()
        ]
        return torch.tensor(flags)
[docs]
@configclass
class SitChecker(BaseChecker):
    """Flag envs whose robot `z_height` value is below 0.5."""

    def check(self, handler: BaseSimHandler) -> torch.BoolTensor:
        """
        Evaluate `z_height` for the robot in every state returned by `handler.get_states()`.

        Returns:
            A tensor with one entry per state, True where the value is below 0.5.
        """
        from metasim.utils.humanoid_robot_util import z_height

        flags = [
            bool(z_height(env_state, handler._robot.name) < 0.5)
            for env_state in handler.get_states()
        ]
        return torch.tensor(flags)
[docs]
@configclass
class BalanceChecker(BaseChecker):
    """Flag envs whose robot `z_height` value is below 0.8."""

    def check(self, handler: BaseSimHandler) -> torch.BoolTensor:
        """
        Evaluate `z_height` for the robot in every state returned by `handler.get_states()`.

        Returns:
            A tensor with one entry per state, True where the value is below 0.8.
        """
        from metasim.utils.humanoid_robot_util import z_height

        flags = [
            bool(z_height(env_state, handler._robot.name) < 0.8)
            for env_state in handler.get_states()
        ]
        return torch.tensor(flags)
[docs]
@configclass
class StairChecker(BaseChecker):
    """Flag envs whose robot `torso_upright` value is below 0.1."""

    def check(self, handler: BaseSimHandler) -> torch.BoolTensor:
        """
        Evaluate `torso_upright` for the robot in every state returned by `handler.get_states()`.

        Returns:
            A tensor with one entry per state, True where the value is below 0.1.
        """
        from metasim.utils.humanoid_robot_util import torso_upright

        flags = [
            bool(torso_upright(env_state, handler._robot.name) < 0.1)
            for env_state in handler.get_states()
        ]
        return torch.tensor(flags)
[docs]
@configclass
class PushChecker(BaseChecker):
    """Flag envs where the "object" entry is within 0.05 of the "destination" entry."""

    def check(self, handler: BaseSimHandler) -> torch.BoolTensor:
        """
        Measure the distance between the box and its destination in every state.

        Returns:
            A tensor with one entry per state, True where the distance is below 0.05 (success).
        """

        def _reached(env_state) -> bool:
            # Distance between the box position and the destination position.
            goal_dist = torch.norm(env_state["object"]["pos"] - env_state["destination"]["pos"])
            # Success when the box is closer than 0.05 to the destination.
            return bool(goal_dist < 0.05)

        return torch.tensor([_reached(env_state) for env_state in handler.get_states()])