# Source code for basicgym.envs.simulator.base
# Copyright (c) 2023, Haruka Kiyohara, Ren Kishimoto, HAKUHODO Technologies Inc., and Hanjuku-kaso Co., Ltd. All rights reserved.
# Licensed under the Apache 2.0 License.
"""Abstract Base Class for Simulation."""
from abc import ABCMeta, abstractmethod
from dataclasses import dataclass
import numpy as np
@dataclass
class BaseStateTransitionFunction(metaclass=ABCMeta):
    """Base class to define the state transition function.

    Imported as: :class:`basicgym.BaseStateTransitionFunction`

    This class is abstract: a concrete subclass must implement :meth:`step`.

    """

    @abstractmethod
    def step(
        self,
        state: np.ndarray,
        action: np.ndarray,
    ) -> np.ndarray:
        """Update the state based on the presented action.

        Parameters
        ----------
        state: array-like of shape (state_dim, )
            Current state.

        action: array-like of shape (action_dim, )
            Indicating the action chosen by the agent.

        Returns
        -------
        state: array-like of shape (state_dim, )
            Next state.

        Raises
        ------
        NotImplementedError
            If a subclass delegates to this base implementation.

        """
        raise NotImplementedError
@dataclass
class BaseRewardFunction(metaclass=ABCMeta):
    """Base class to define the expected immediate reward function.

    Imported as: :class:`basicgym.BaseRewardFunction`

    This class is abstract: a concrete subclass must implement
    :meth:`mean_reward_function`.

    Note
    ----
    :meth:`sample_reward` reads ``self.reward_type``, ``self.reward_std`` and
    ``self.random_``. None of them is defined in this base class, so the
    concrete subclass is expected to provide them (presumably as dataclass
    fields / in ``__post_init__`` -- verify against the subclasses).

    """

    @abstractmethod
    def mean_reward_function(
        self,
        state: np.ndarray,
        action: np.ndarray,
    ) -> float:
        """Expected immediate reward function.

        Parameters
        ----------
        state: array-like of shape (state_dim, )
            State in the RL environment.

        action: array-like of shape (action_dim, )
            Indicating the action chosen by the agent.

        Returns
        -------
        mean_reward_function: float
            Expected immediate reward function conditioned on the state and action.

        """
        raise NotImplementedError

    def sample_reward(
        self,
        state: np.ndarray,
        action: np.ndarray,
    ) -> float:
        """Sample a reward around the expected immediate reward.

        Parameters
        ----------
        state: array-like of shape (state_dim, )
            State in the RL environment.

        action: array-like of shape (action_dim, )
            Indicating the action chosen by the agent.

        Returns
        -------
        reward: float or int
            When ``self.reward_type == "continuous"``, a draw from a normal
            distribution with mean ``mean_reward_function(state, action)`` and
            standard deviation ``self.reward_std``. For any other reward type,
            a single Bernoulli draw (0 or 1) whose success probability is
            ``mean_reward_function(state, action)``.

        """
        mean_reward_function = self.mean_reward_function(state, action)

        if self.reward_type == "continuous":
            reward = self.random_.normal(
                loc=mean_reward_function, scale=self.reward_std
            )
        else:
            # One trial of a binomial is a Bernoulli draw. The NumPy method is
            # spelled `binomial`; the former `binominal` raised AttributeError.
            reward = self.random_.binomial(1, p=mean_reward_function)

        return reward