Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
140 changes: 140 additions & 0 deletions pufferlib/gp_torch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
"""
Drop-in replacement for GPyTorch exact GP training, CUDA implementation only.
"""

import numpy as np
import torch
import torch.nn as nn

from pufferlib import _C

class _MLLFunction(torch.autograd.Function):
    """Autograd bridge for a marginal likelihood computed outside of PyTorch.

    Both the value and its gradient are read from the ``backend`` object;
    autograd itself differentiates nothing here, it only routes the gradient
    returned by ``backward`` to whatever produced ``params``.
    """

    @staticmethod
    def forward(ctx, params, backend):
        """Wrap ``backend.log_marginal_likelihood`` in a tensor.

        Args:
            params: tensor whose dtype and device the result copies. Its
                values are not read in this method.
            backend: object exposing ``log_marginal_likelihood`` and
                ``mll_grad()``.

        Returns:
            A new tensor holding the backend's current log marginal likelihood.
        """
        # Stash the backend so backward() can ask it for the gradient.
        ctx.backend = backend
        return torch.tensor(backend.log_marginal_likelihood,
                            dtype=params.dtype, device=params.device)

    @staticmethod
    def backward(ctx, grad_output):
        """Scale the backend's gradient by the incoming ``grad_output``.

        Returns one entry per ``forward`` input: the gradient for ``params``
        and ``None`` for the non-differentiable ``backend`` argument.
        """
        # NOTE(review): presumably mll_grad() yields one value per element of
        # ``params``, in the same order -- confirm against the backend.
        backend_grad = grad_output.new_tensor(ctx.backend.mll_grad())
        return grad_output * backend_grad, None


def _np32(x):
    """Convert ``x`` to a C-contiguous float32 NumPy array.

    Torch tensors are first detached from the autograd graph and moved to the
    CPU; any other input is handed to NumPy unchanged.
    """
    array_like = x.detach().cpu().numpy() if isinstance(x, torch.Tensor) else x
    return np.ascontiguousarray(array_like, dtype=np.float32)


class GaussianProcess(nn.Module):
    """``nn.Module`` facade over the compiled ``_C.GaussianProcess`` backend.

    The backend's four raw hyperparameters are mirrored as float64
    ``nn.Parameter`` objects so that a regular torch optimizer can update
    them. Before every backend call the current parameter values are pushed
    back into the backend by ``_sync``.
    """

    def __init__(self, dim, capacity,
                 lengthscale=1.0, outputscale=1.0, noise=1e-2, offset=1.0,
                 use_cuda=True):
        """Create a backend GP and mirror its raw hyperparameters.

        NOTE(review): ``use_cuda`` is accepted but never read in this class.
        """
        super().__init__()
        backend = _C.GaussianProcess(
            dim=dim, capacity=capacity,
            lengthscale=lengthscale, outputscale=outputscale,
            noise=noise, offset=offset)
        self._adopt_backend(backend)

    def _adopt_backend(self, backend):
        """Attach ``backend`` and register its raw hyperparameters.

        Registration order (lengthscale, outputscale, noise, offset) is the
        order in which ``parameters()`` yields them.
        """
        self._backend = backend
        self.raw_lengthscale = nn.Parameter(
            torch.tensor(np.asarray(backend.raw_lengthscale), dtype=torch.float64))
        self.raw_outputscale = nn.Parameter(
            torch.tensor(backend.raw_outputscale, dtype=torch.float64))
        self.raw_noise = nn.Parameter(
            torch.tensor(backend.raw_noise, dtype=torch.float64))
        self.raw_offset = nn.Parameter(
            torch.tensor(backend.raw_offset, dtype=torch.float64))

    @property
    def lengthscale(self):
        """Backend ``lengthscale`` as a NumPy array."""
        return np.asarray(self._backend.lengthscale)

    @property
    def outputscale(self):
        """Backend ``outputscale`` value."""
        return self._backend.outputscale

    @property
    def noise(self):
        """Backend ``noise`` value."""
        return self._backend.noise

    @property
    def offset(self):
        """Backend ``offset`` value."""
        return self._backend.offset

    @property
    def log_marginal_likelihood(self):
        """Backend ``log_marginal_likelihood`` value (no sync, no recompute)."""
        return self._backend.log_marginal_likelihood

    @property
    def lengthscale_range(self):
        """``(min, max)`` of the lengthscale array, as Python floats."""
        ells = self.lengthscale
        return float(ells.min()), float(ells.max())

    @property
    def n(self):
        """Backend ``n`` value."""
        return self._backend.n

    @property
    def dim(self):
        """Backend ``dim`` value."""
        return self._backend.dim

    @property
    def capacity(self):
        """Backend ``capacity`` value."""
        return self._backend.capacity

    def fit(self, X, y):
        """Sync hyperparameters, then fit the backend on float32 copies of ``X``/``y``."""
        self._sync()
        self._backend.fit(_np32(X), _np32(y))

    def recompute(self):
        """Sync hyperparameters, then ask the backend to recompute."""
        self._sync()
        self._backend.recompute()

    def mll(self, recompute=True):
        """Return the backend's log marginal likelihood as a differentiable tensor.

        The raw parameters are concatenated into one vector (lengthscale
        entries first, then outputscale, noise, offset) so that the gradient
        reported by the backend flows back to each ``nn.Parameter``.

        Args:
            recompute: when true, the backend recomputes after the sync and
                before the value is read.
        """
        self._sync()
        if recompute:
            self._backend.recompute()
        scalar_params = (self.raw_outputscale, self.raw_noise, self.raw_offset)
        flat = torch.cat([self.raw_lengthscale,
                          *(p.unsqueeze(0) for p in scalar_params)])
        return _MLLFunction.apply(flat, self._backend)

    def predict(self, Xs):
        """Return ``(means, variances)`` tensors predicted by the backend for ``Xs``."""
        self._sync()
        mean_np, var_np = self._backend.predict(_np32(Xs))
        return torch.from_numpy(mean_np), torch.from_numpy(var_np)

    def eval(self):
        """Switch to eval mode, then sync and recompute the backend.

        The ``hasattr`` guard skips the backend step when ``_backend`` has not
        been attached yet.
        """
        module = super().eval()
        if hasattr(self, '_backend'):
            self.recompute()
        return module

    def save(self, path):
        """Sync hyperparameters, then let the backend save itself to ``path``."""
        self._sync()
        self._backend.save(path)

    @classmethod
    def load(cls, path, extra_cap=0, use_cuda=True):
        """Alternate constructor: build an instance from a backend saved at ``path``.

        ``__init__`` is bypassed because it would create a fresh backend;
        only the ``nn.Module`` machinery is initialised before the loaded
        backend is attached.

        NOTE(review): ``use_cuda`` is accepted but never read here.
        """
        obj = cls.__new__(cls)
        nn.Module.__init__(obj)
        obj._adopt_backend(_C.GaussianProcess.load(path, extra_cap))
        return obj

    def __repr__(self):
        ells = self.lengthscale
        # A single lengthscale is shown as a number, several as a min..max span.
        ell_s = (f"{ells[0]:.3g}" if len(ells) == 1
                 else f"[{np.min(ells):.3g}..{np.max(ells):.3g}]")
        return (f"<GaussianProcess dim={self.dim} n={self.n} cap={self.capacity} "
                f"ell={ell_s} sf={self.outputscale:.3g} "
                f"noise={self.noise:.3g}>")

    def _sync(self):
        """Copy the current ``nn.Parameter`` values into the backend.

        The lengthscale is sent as a float32 array, the other three as
        Python floats.
        """
        backend = self._backend
        backend.raw_lengthscale = (
            self.raw_lengthscale.detach().cpu().numpy().astype(np.float32))
        for name in ('raw_outputscale', 'raw_noise', 'raw_offset'):
            setattr(backend, name, float(getattr(self, name).item()))
145 changes: 63 additions & 82 deletions pufferlib/sweep.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,18 +9,13 @@
import pufferlib

import torch
import gpytorch
from gpytorch.models import ExactGP
from gpytorch.likelihoods import GaussianLikelihood
from gpytorch.kernels import MaternKernel, PolynomialKernel, ScaleKernel, AdditiveKernel
from gpytorch.means import ConstantMean
from gpytorch.mlls import ExactMarginalLogLikelihood
from gpytorch.priors import LogNormalPrior
from scipy.optimize import minimize
from scipy.stats.qmc import Sobol
from scipy.spatial import KDTree
from sklearn.linear_model import LogisticRegression

from pufferlib.gp_torch import GaussianProcess

EPSILON = 1e-6

def unroll_nested_dict(d):
Expand Down Expand Up @@ -146,7 +141,8 @@ def _params_from_puffer_sweep(sweep_config, only_include=None):

for name, param in sweep_config.items():
if name in ('method', 'metric', 'metric_distribution', 'goal', 'downsample', 'use_gpu', 'prune_pareto',
'sweep_only', 'max_suggestion_cost', 'early_stop_quantile', 'gpus', 'max_runs'):
'sweep_only', 'max_suggestion_cost', 'early_stop_quantile', 'gpus', 'max_runs',
'match_enemy_model_path', 'match_num_games', 'match_enemy_hidden_size', 'match_enemy_num_layers'):
continue

assert isinstance(param, dict), f'Param {name} is not a dict'
Expand Down Expand Up @@ -396,53 +392,41 @@ def early_stop(self, logs, target_key):
return False


class ExactGPModel(ExactGP):
def __init__(self, train_x, train_y, likelihood, x_dim):
super(ExactGPModel, self).__init__(train_x, train_y, likelihood)
self.mean_module = ConstantMean()
# Matern 3/2 kernel (equivalent to Pyro's Matern32)
matern_kernel = MaternKernel(nu=1.5, ard_num_dims=x_dim)

# NOTE: setting this constraint changes GP behavior, including the lengthscale
# even though the lengthscale is well within the range ... Commenting out for now.
# lengthscale_constraint = gpytorch.constraints.Interval(0.01, 10.0)
# matern_kernel = MaternKernel(nu=1.5, ard_num_dims=x_dim, lengthscale_constraint=lengthscale_constraint)
_NOISE_PRIOR_MU = math.log(1e-2) # LogNormal prior on noise, from HEBO
_NOISE_PRIOR_SIGMA = 0.5

linear_kernel = PolynomialKernel(power=1)
self.covar_module = ScaleKernel(AdditiveKernel(linear_kernel, matern_kernel))
def _noise_log_prior(raw_noise_param):
    """Log-density of the noise prior, expressed in raw-parameter space.

    The constrained noise is ``softplus(raw) + 1e-4``. It receives a
    LogNormal(``_NOISE_PRIOR_MU``, ``_NOISE_PRIOR_SIGMA``) log-density, with
    the constant ``-0.5 * log(2 * pi)`` omitted, plus the log-Jacobian of the
    softplus change of variables:

        log p(raw) = log p_lognormal(softplus(raw) + lb) + log softplus'(raw)

    Args:
        raw_noise_param: tensor holding the unconstrained noise parameter.

    Returns:
        Tensor of the same shape with the log prior density, up to an
        additive constant.
    """
    noise = torch.nn.functional.softplus(raw_noise_param) + 1e-4
    log_noise = torch.log(noise)
    z = (log_noise - _NOISE_PRIOR_MU) / _NOISE_PRIOR_SIGMA
    log_p = -0.5 * z ** 2 - log_noise - math.log(_NOISE_PRIOR_SIGMA)
    # softplus'(raw) = sigmoid(raw). logsigmoid is the overflow-safe form of
    # -log1p(exp(-raw)): for strongly negative raw values exp(-raw) becomes
    # inf, which turns the naive expression into -inf with a NaN gradient.
    log_jac = torch.nn.functional.logsigmoid(raw_noise_param)
    return log_p + log_jac

def forward(self, x):
mean_x = self.mean_module(x)
covar_x = self.covar_module(x)
return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)

@property
def lengthscale_range(self):
# Get lengthscale from MaternKernel
lengthscale = self.covar_module.base_kernel.kernels[1].lengthscale.tolist()[0]
return min(lengthscale), max(lengthscale)
def _fit_gp(gp_model, optimizer, train_x, train_y, training_iter=50):
if isinstance(train_x, torch.Tensor):
train_x = train_x.cpu().numpy()
if isinstance(train_y, torch.Tensor):
train_y = train_y.cpu().numpy()

def train_gp_model(model, likelihood, mll, optimizer, train_x, train_y, training_iter=50):
model.train()
likelihood.train()
model.set_train_data(inputs=train_x, targets=train_y, strict=False)
gp_model.train()
gp_model.fit(train_x, train_y)

loss = None
for _ in range(training_iter):
try:
optimizer.zero_grad()
output = model(train_x)
loss = -mll(output, train_y)
loss = -gp_model.mll() - _noise_log_prior(gp_model.raw_noise)
loss.backward()
optimizer.step()
loss = loss.detach()

except gpytorch.utils.errors.NotPSDError:
# It's rare but it does happen. Hope it's a transient issue.
except Exception:
break

model.eval()
likelihood.eval()
gp_model.eval()
return loss.item() if loss is not None else 0


Expand Down Expand Up @@ -592,36 +576,40 @@ def __init__(self,
self.stop_threshold_model = RobustLogCostModel(quantile=sweep_config['early_stop_quantile'])
self.upper_cost_threshold = -np.inf

# Use 64 bit for GP regression
with default_tensor_dtype(torch.float64):
# Params taken from HEBO: https://arxiv.org/abs/2012.03826
noise_prior = LogNormalPrior(math.log(1e-2), 0.5)

# Create dummy data for model initialization on the selected device
dummy_x = torch.ones((1, self.hyperparameters.num), device=self.device)
dummy_y = torch.zeros(1, device=self.device)
# Score GP
self.likelihood_score = GaussianLikelihood(noise_prior=deepcopy(noise_prior)).to(self.device)
self.gp_score = ExactGPModel(dummy_x, dummy_y, self.likelihood_score, self.hyperparameters.num).to(self.device)
self.mll_score = ExactMarginalLogLikelihood(self.likelihood_score, self.gp_score).to(self.device)
self.score_opt = torch.optim.Adam(self.gp_score.parameters(), lr=self.gp_learning_rate, amsgrad=True)
# Score GP
self.gp_score = GaussianProcess(dim=self.hyperparameters.num, capacity=self.gp_max_obs, use_cuda=_use_gpu)
self.score_opt = torch.optim.Adam(self.gp_score.parameters(), lr=self.gp_learning_rate, amsgrad=True)

# Cost GP
self.likelihood_cost = GaussianLikelihood(noise_prior=deepcopy(noise_prior)).to(self.device)
self.gp_cost = ExactGPModel(dummy_x, dummy_y, self.likelihood_cost, self.hyperparameters.num).to(self.device)
self.mll_cost = ExactMarginalLogLikelihood(self.likelihood_cost, self.gp_cost).to(self.device)
self.cost_opt = torch.optim.Adam(self.gp_cost.parameters(), lr=self.gp_learning_rate, amsgrad=True)
# Cost GP
self.gp_cost = GaussianProcess(dim=self.hyperparameters.num, capacity=self.gp_max_obs, use_cuda=_use_gpu)
self.cost_opt = torch.optim.Adam(self.gp_cost.parameters(), lr=self.gp_learning_rate, amsgrad=True)

# Buffers for GP training and inference
self.gp_params_buffer = torch.empty(self.gp_max_obs, self.hyperparameters.num, dtype=torch.float64, device=self.device)
self.gp_score_buffer = torch.empty(self.gp_max_obs, dtype=torch.float64, device=self.device)
self.gp_cost_buffer = torch.empty(self.gp_max_obs, dtype=torch.float64, device=self.device)
self.infer_batch_buffer = torch.empty(self.infer_batch_size, self.hyperparameters.num, dtype=torch.float64, device=self.device)

# Buffers for GP training and inference
self.gp_params_buffer = torch.empty(self.gp_max_obs, self.hyperparameters.num, device=self.device)
self.gp_score_buffer = torch.empty(self.gp_max_obs, device=self.device)
self.gp_cost_buffer = torch.empty(self.gp_max_obs, device=self.device)
self.infer_batch_buffer = torch.empty(self.infer_batch_size, self.hyperparameters.num, device=self.device)
_CUDA_ATTRS = ('gp_score', 'gp_cost', 'score_opt', 'cost_opt',
'gp_params_buffer', 'gp_score_buffer',
'gp_cost_buffer', 'infer_batch_buffer')

def __getstate__(self):
    """Return a shallow copy of the instance dict without the ``_CUDA_ATTRS`` entries.

    The instance itself is left untouched; attributes listed in
    ``_CUDA_ATTRS`` that are not present are simply ignored.
    """
    excluded = set(self._CUDA_ATTRS)
    return {key: value for key, value in self.__dict__.items()
            if key not in excluded}

def __setstate__(self, state):
    """Restore ``state`` and fill every absent ``_CUDA_ATTRS`` entry with ``None``.

    Attributes already present after the update keep their value.
    """
    instance_dict = self.__dict__
    instance_dict.update(state)
    for attr in self._CUDA_ATTRS:
        instance_dict.setdefault(attr, None)

def to(self, device):
self.device = torch.device(device)
for attr in ('gp_score', 'gp_cost', 'likelihood_score', 'likelihood_cost',
'mll_score', 'mll_cost', 'gp_params_buffer', 'gp_score_buffer',
for attr in ('gp_score', 'gp_cost',
'gp_params_buffer', 'gp_score_buffer',
'gp_cost_buffer', 'infer_batch_buffer'):
setattr(self, attr, getattr(self, attr).to(self.device))
for opt in (self.score_opt, self.cost_opt):
Expand Down Expand Up @@ -713,10 +701,8 @@ def _train_gp_models(self):
log_c_norm_tensor = self.gp_cost_buffer[:num_sampled]
log_c_norm_tensor.copy_(torch.from_numpy(log_c_norm))

with warnings.catch_warnings():
warnings.simplefilter("ignore", gpytorch.utils.warnings.NumericalWarning)
score_loss = train_gp_model(self.gp_score, self.likelihood_score, self.mll_score, self.score_opt, params_tensor, y_norm_tensor, training_iter=self.gp_training_iter)
cost_loss = train_gp_model(self.gp_cost, self.likelihood_cost, self.mll_cost, self.cost_opt, params_tensor, log_c_norm_tensor, training_iter=self.gp_training_iter)
score_loss = _fit_gp(self.gp_score, self.score_opt, params_tensor, y_norm_tensor, training_iter=self.gp_training_iter)
cost_loss = _fit_gp(self.gp_cost, self.cost_opt, params_tensor, log_c_norm_tensor, training_iter=self.gp_training_iter)

return score_loss, cost_loss

Expand Down Expand Up @@ -818,30 +804,25 @@ def suggest(self, fill, fixed_total_timesteps=None):
# Batch predictions to avoid GPU OOM for large number of suggestions
gp_y_norm_list, gp_log_c_norm_list = [], []

with torch.no_grad(), gpytorch.settings.fast_pred_var(), warnings.catch_warnings():
warnings.simplefilter("ignore", gpytorch.utils.warnings.NumericalWarning)

# Create a reusable buffer on the device to avoid allocating a huge tensor
with torch.no_grad():
for i in range(0, len(suggestions), self.infer_batch_size):
batch_numpy = suggestions[i:i+self.infer_batch_size]
current_batch_size = len(batch_numpy)

# Use a slice of the buffer if the current batch is smaller
batch_tensor = self.infer_batch_buffer[:current_batch_size]
batch_tensor.copy_(torch.from_numpy(batch_numpy))

try:
# Score and cost prediction
pred_y_mean = self.likelihood_score(self.gp_score(batch_tensor)).mean.cpu()
pred_c_mean = self.likelihood_cost(self.gp_cost(batch_tensor)).mean.cpu()

pred_y_mean, _ = self.gp_score.predict(batch_tensor)
pred_c_mean, _ = self.gp_cost.predict(batch_tensor)
except RuntimeError:
# Handle numerical errors during GP prediction
pred_y_mean, pred_c_mean = torch.zeros(current_batch_size)

gp_y_norm_list.append(pred_y_mean.cpu())
gp_log_c_norm_list.append(pred_c_mean.cpu())
pred_y_mean = torch.zeros(current_batch_size, dtype=torch.float64)
pred_c_mean = torch.zeros(current_batch_size, dtype=torch.float64)

gp_y_norm_list.append(pred_y_mean)
gp_log_c_norm_list.append(pred_c_mean)
del pred_y_mean, pred_c_mean

# Concatenate results from all batches
Expand Down
Loading