Source code for scope_rl.ope.weight_value_learning.function

# Copyright (c) 2023, Haruka Kiyohara, Ren Kishimoto, HAKUHODO Technologies Inc., and Hanjuku-kaso Co., Ltd. All rights reserved.
# Licensed under the Apache 2.0 License.

"""Weight and Value Functions."""
import torch
from torch import nn
import torch.nn.functional as F
from pytorch_revgrad import RevGrad


class VFunction(nn.Module):
    """State value function V(s), shared by discrete and continuous action spaces.

    Bases: :class:`torch.nn.Module`

    Imported as: :class:`scope_rl.ope.weight_value_learning.function.VFunction`

    The network is a single hidden layer MLP: ``Linear -> ReLU -> Linear``
    with one output unit.

    Parameters
    -------
    state_dim: int (> 0)
        Dimensions of the state space.

    hidden_dim: int, default=100 (> 0)
        Hidden dimension of the network.

    """

    def __init__(self, state_dim: int, hidden_dim: int = 100):
        super().__init__()
        # Creation order is kept (fc1, then fc2) so seeded initialisation
        # reproduces the same parameters.
        self.fc1 = nn.Linear(state_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, 1)

    def forward(self, state: torch.Tensor) -> torch.Tensor:
        """Evaluate the value of each state.

        Parameters
        -------
        state: torch.Tensor
            States whose last dimension is ``state_dim``
            (presumably shaped ``(batch, state_dim)``; confirm with callers).

        Returns
        -------
        torch.Tensor
            Flattened (1-D) tensor of predicted values.

        """
        hidden = self.fc1(state)
        hidden = F.relu(hidden)
        value = self.fc2(hidden)
        return value.flatten()
class StateWeightFunction(nn.Module):
    """State weight function w(s), shared by discrete and continuous action spaces.

    Bases: :class:`torch.nn.Module`

    Imported as: :class:`scope_rl.ope.weight_value_learning.function.StateWeightFunction`

    The network is ``Linear -> ReLU -> Linear -> softplus``; the softplus keeps
    the predicted weights positive.

    Parameters
    -------
    state_dim: int (> 0)
        Dimensions of the state space.

    hidden_dim: int, default=100 (> 0)
        Hidden dimension of the network.

    enable_gradient_reversal: bool = False
        Whether to enable gradient reversal layer (for loss maximization).

    """

    def __init__(
        self,
        state_dim: int,
        hidden_dim: int = 100,
        enable_gradient_reversal: bool = False,
    ):
        super().__init__()
        self.enable_gradient_reversal = enable_gradient_reversal

        self.fc1 = nn.Linear(state_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, 1)
        # The reversal layer is always created; it is applied in ``forward``
        # only when ``enable_gradient_reversal`` is set.
        self.grl = RevGrad()

    def forward(self, state: torch.Tensor) -> torch.Tensor:
        """Evaluate the weight of each state.

        Parameters
        -------
        state: torch.Tensor
            States whose last dimension is ``state_dim``
            (presumably shaped ``(batch, state_dim)``; confirm with callers).

        Returns
        -------
        torch.Tensor
            Flattened (1-D) tensor of softplus-activated weights.

        """
        hidden = F.relu(self.fc1(state))
        weight = F.softplus(self.fc2(hidden))

        if not self.enable_gradient_reversal:
            return weight.flatten()

        return self.grl(weight).flatten()
class DiscreteQFunction(nn.Module):
    """Q Function (for discrete action space).

    Bases: :class:`torch.nn.Module`

    Imported as: :class:`scope_rl.ope.weight_value_learning.function.DiscreteQFunction`

    The network is ``Linear -> ReLU -> Linear`` with one output per action.

    Parameters
    -------
    n_actions: int (> 0)
        Number of actions.

    state_dim: int (> 0)
        Dimensions of the state space.

    hidden_dim: int, default=100 (> 0)
        Hidden dimension of the network.

    device: str, default="cuda:0"
        Specifies device used for torch. Only the one-hot lookup table is
        created on this device here; the linear layers are created on the
        default device and are expected to be moved by the caller.

    """

    def __init__(
        self,
        n_actions: int,
        state_dim: int,
        hidden_dim: int = 100,
        device: str = "cuda:0",
    ):
        super().__init__()
        self.fc1 = nn.Linear(state_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, n_actions)
        # Registered as a buffer so that ``.to()``, ``.cuda()``, ``.double()``
        # etc. keep the lookup table on the same device/dtype as the layers.
        # ``persistent=False`` keeps it out of ``state_dict`` so previously
        # saved checkpoints still load without key mismatches.
        self.register_buffer(
            "action_onehot",
            torch.eye(n_actions, device=device),
            persistent=False,
        )

    def forward(
        self,
        state: torch.Tensor,
        action: torch.Tensor,
    ) -> torch.Tensor:
        """Return the Q value of the given action for each state.

        Parameters
        -------
        state: torch.Tensor
            States whose last dimension is ``state_dim``.

        action: torch.Tensor
            Integer action indices used to index the one-hot table
            (presumably shaped ``(batch,)``; confirm with callers).

        Returns
        -------
        torch.Tensor
            Q values of the chosen actions, reduced over dimension 1.

        """
        onehot = self.action_onehot[action]
        return (self.all(state) * onehot).sum(dim=1)

    def all(
        self,
        state: torch.Tensor,
    ) -> torch.Tensor:
        """Return the Q values of every action for each state."""
        hidden = F.relu(self.fc1(state))
        return self.fc2(hidden)

    def max(
        self,
        state: torch.Tensor,
    ) -> torch.Tensor:
        """Return the largest Q value over actions (dimension 1) for each state."""
        return torch.max(self.all(state), dim=1)[0]

    def argmax(
        self,
        state: torch.Tensor,
    ) -> torch.Tensor:
        """Return the index of the largest Q value over actions for each state."""
        return torch.max(self.all(state), dim=1)[1]

    def expectation(
        self,
        state: torch.Tensor,
        action_distribution: torch.Tensor,
    ) -> torch.Tensor:
        """Return the Q values weighted by ``action_distribution`` and summed over actions.

        Parameters
        -------
        state: torch.Tensor
            States whose last dimension is ``state_dim``.

        action_distribution: torch.Tensor
            Per-action weights multiplied element-wise with the Q values
            (presumably action choice probabilities shaped
            ``(batch, n_actions)``; confirm with callers).

        Returns
        -------
        torch.Tensor
            Weighted sum of Q values over dimension 1.

        """
        return (self.all(state) * action_distribution).sum(dim=1)
class ContinuousQFunction(nn.Module):
    """Q Function (for continuous action space).

    Bases: :class:`torch.nn.Module`

    Imported as: :class:`scope_rl.ope.weight_value_learning.function.ContinuousQFunction`

    The network takes the concatenated state and action as input:
    ``Linear -> ReLU -> Linear`` with one output unit.

    Parameters
    -------
    action_dim: int (> 0)
        Dimensions of the action space.

    state_dim: int (> 0)
        Dimensions of the state space.

    hidden_dim: int, default=100 (> 0)
        Hidden dimension of the network.

    """

    def __init__(
        self,
        action_dim: int,
        state_dim: int,
        hidden_dim: int = 100,
    ):
        super().__init__()
        self.fc1 = nn.Linear(state_dim + action_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, 1)

    def forward(
        self,
        state: torch.Tensor,
        action: torch.Tensor,
    ) -> torch.Tensor:
        """Return the Q value of each (state, action) pair.

        Parameters
        -------
        state: torch.Tensor
            2-D tensor of states, concatenated with ``action`` along dimension 1.

        action: torch.Tensor
            2-D tensor of actions with the same number of rows as ``state``.

        Returns
        -------
        torch.Tensor
            Flattened (1-D) tensor of predicted Q values, one per row.

        """
        state_action = torch.cat((state, action), dim=1)
        hidden = F.relu(self.fc1(state_action))
        # The action is already part of the network input and fc2 emits one
        # scalar per row, so the scalar itself is the Q value. Multiplying it
        # by the action and summing (as the discrete one-hot version does)
        # would scale Q by the sum of action components and force Q(s, 0) = 0.
        return self.fc2(hidden).flatten()
class DiscreteStateActionWeightFunction(nn.Module):
    """State Action Weight Function (for discrete action space).

    Bases: :class:`torch.nn.Module`

    Imported as: :class:`scope_rl.ope.weight_value_learning.function.DiscreteStateActionWeightFunction`

    The network is ``Linear -> ReLU -> Linear -> softplus`` with one output
    per action; the softplus keeps the predicted weights positive.

    Parameters
    -------
    n_actions: int (> 0)
        Number of actions.

    state_dim: int (> 0)
        Dimensions of the state space.

    hidden_dim: int, default=100 (> 0)
        Hidden dimension of the network.

    enable_gradient_reversal: bool = False
        Whether to enable gradient reversal layer (for loss maximization).

    device: str, default="cuda:0"
        Specifies device used for torch. Only the one-hot lookup table is
        created on this device here; the linear layers are created on the
        default device and are expected to be moved by the caller.

    """

    def __init__(
        self,
        n_actions: int,
        state_dim: int,
        hidden_dim: int = 100,
        enable_gradient_reversal: bool = False,
        device: str = "cuda:0",
    ):
        super().__init__()
        self.enable_gradient_reversal = enable_gradient_reversal

        self.fc1 = nn.Linear(state_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, n_actions)
        # Registered as a buffer so that ``.to()``, ``.cuda()``, ``.double()``
        # etc. keep the lookup table on the same device/dtype as the layers.
        # ``persistent=False`` keeps it out of ``state_dict`` so previously
        # saved checkpoints still load without key mismatches.
        self.register_buffer(
            "action_onehot",
            torch.eye(n_actions, device=device),
            persistent=False,
        )
        # The reversal layer is always created; it is applied in ``forward``
        # only when ``enable_gradient_reversal`` is set.
        self.grl = RevGrad()

    def forward(
        self,
        state: torch.Tensor,
        action: torch.Tensor,
    ) -> torch.Tensor:
        """Return the weight of the given action for each state.

        Parameters
        -------
        state: torch.Tensor
            States whose last dimension is ``state_dim``.

        action: torch.Tensor
            Integer action indices used to index the one-hot table
            (presumably shaped ``(batch,)``; confirm with callers).

        Returns
        -------
        torch.Tensor
            Weights of the chosen actions, reduced over dimension 1.

        """
        onehot = self.action_onehot[action]
        hidden = F.relu(self.fc1(state))
        weights = F.softplus(self.fc2(hidden))

        if self.enable_gradient_reversal:
            weights = self.grl(weights)

        return (weights * onehot).sum(dim=1)
class ContinuousStateActionWeightFunction(nn.Module):
    """State Action Weight Function (for continuous action space).

    Bases: :class:`torch.nn.Module`

    Imported as: :class:`scope_rl.ope.weight_value_learning.function.ContinuousStateActionWeightFunction`

    The network takes the concatenated state and action as input:
    ``Linear -> ReLU -> Linear -> softplus`` with one output unit; the
    softplus keeps the predicted weights positive.

    Parameters
    -------
    action_dim: int (> 0)
        Dimensions of the action space.

    state_dim: int (> 0)
        Dimensions of the state space.

    hidden_dim: int, default=100 (> 0)
        Hidden dimension of the network.

    enable_gradient_reversal: bool = False
        Whether to enable gradient reversal layer (for loss maximization).

    """

    def __init__(
        self,
        action_dim: int,
        state_dim: int,
        hidden_dim: int = 100,
        enable_gradient_reversal: bool = False,
    ):
        super().__init__()
        self.enable_gradient_reversal = enable_gradient_reversal

        self.fc1 = nn.Linear(state_dim + action_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, 1)
        # The reversal layer is always created; it is applied in ``forward``
        # only when ``enable_gradient_reversal`` is set.
        self.grl = RevGrad()

    def forward(
        self,
        state: torch.Tensor,
        action: torch.Tensor,
    ) -> torch.Tensor:
        """Return the weight of each (state, action) pair.

        Parameters
        -------
        state: torch.Tensor
            2-D tensor of states, concatenated with ``action`` along dimension 1.

        action: torch.Tensor
            2-D tensor of actions with the same number of rows as ``state``.

        Returns
        -------
        torch.Tensor
            Flattened (1-D) tensor of softplus-activated weights, one per row.

        """
        state_action = torch.cat((state, action), dim=1)
        hidden = F.relu(self.fc1(state_action))
        weight = F.softplus(self.fc2(hidden))

        if self.enable_gradient_reversal:
            weight = self.grl(weight)

        # ``flatten`` rather than a bare ``squeeze()``: squeeze would also drop
        # the batch dimension when it has size one and return a 0-dim tensor.
        return weight.flatten()