
Commit

Add files via upload
mbai76 authored Dec 17, 2024
1 parent f1445ce commit 83e76c6
Showing 55 changed files with 6,046 additions and 9 deletions.
2 changes: 1 addition & 1 deletion CONTRIBUTING.md
@@ -56,4 +56,4 @@ If you discover a potential security issue in this project we ask that you notif

## Licensing

-See the [LICENSE](LICENSE) file for our project's licensing. We will ask you to confirm the licensing of your contribution.
+See the [LICENSE](LICENSE) file for our project's licensing. We will ask you to confirm the licensing of your contribution.
407 changes: 407 additions & 0 deletions LICENSE.txt

Large diffs are not rendered by default.

259 changes: 251 additions & 8 deletions README.md

Large diffs are not rendered by default.

Empty file added __init__.py
89 changes: 89 additions & 0 deletions data/webarena_map_google_fixed_eval.jsonl

Large diffs are not rendered by default.
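The evaluation set ships as JSON Lines. A minimal, hypothetical loading sketch (the per-record schema is not shown in this commit; only the one-object-per-line format implied by the .jsonl extension is assumed):

import json

# Hypothetical loader -- the record schema is not shown in this commit.
with open("data/webarena_map_google_fixed_eval.jsonl") as f:
    tasks = [json.loads(line) for line in f if line.strip()]
print(f"loaded {len(tasks)} evaluation records")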

3 changes: 3 additions & 0 deletions pae/algorithms/__init__.py
@@ -0,0 +1,3 @@
from .onpolicy_train_loop import onpolicy_train_loop
from .worker_collect_loop import worker_collect_loop
from .parallel_utils import remote_collect_trajectories
1 change: 1 addition & 0 deletions pae/algorithms/base/__init__.py
@@ -0,0 +1 @@
from .base_trainer import BaseTrainer
43 changes: 43 additions & 0 deletions pae/algorithms/base/base_trainer.py
@@ -0,0 +1,43 @@
class BaseTrainer():
    """Interface shared by trainers: holds the agent, accelerator, and
    optimization hyperparameters; subclasses override the methods below."""

    def __init__(self, agent,
                 accelerator,
                 lm_lr: float = 1e-5,
                 batch_size: int = 4,
                 max_grad_norm: float = 1.0):
        super().__init__()
        self.agent = agent
        self.batch_size = batch_size
        self.max_grad_norm = max_grad_norm
        self.accelerator = accelerator

    def prepare(self):
        # Wrap the model and optimizer with the accelerator (no-op here).
        return

    def actor_loss(self, observation, action, **kwargs):
        # Compute and backpropagate the policy loss; return a metrics dict.
        return {}

    def update(self, trajectories, actor_trajectories, iter):
        # Run one training pass over sampled trajectories; return a metrics dict.
        return {}

    def validate(self, trajectories):
        # Evaluate on held-out trajectories; return a metrics dict.
        return {}

    def save(self, path):
        # Persist trainer state to disk.
        return

    def load(self, path):
        # Restore trainer state from disk.
        return
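For orientation, a hypothetical sketch (not part of this commit) of how a subclass fills in this interface; the metric-dict convention follows BCTrainer below:

from pae.algorithms.base import BaseTrainer

class NoOpTrainer(BaseTrainer):
    # Illustrative only: logs a constant loss and never updates weights.
    def actor_loss(self, observation, action, **kwargs):
        return {"noop.loss": 0.0}

    def update(self, trajectories, actor_trajectories, iter):
        # A real trainer would build a DataLoader over `trajectories`,
        # call actor_loss per batch, and step an optimizer here.
        return {"noop.loss": 0.0}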

1 change: 1 addition & 0 deletions pae/algorithms/filteredbc/__init__.py
@@ -0,0 +1 @@
from .trainer import BCTrainer
119 changes: 119 additions & 0 deletions pae/algorithms/filteredbc/trainer.py
@@ -0,0 +1,119 @@
import random

import torch
from tqdm import tqdm
from torch.utils.data import DataLoader

from pae.data import DummyImageDataset


def dict_mean(dict_list):
    # Average each metric key across a list of per-batch metric dicts.
    mean_dict = {}
    if len(dict_list) > 0:
        for key in dict_list[0].keys():
            mean_dict[key] = sum(d[key] for d in dict_list) / len(dict_list)
    return mean_dict


class BCTrainer():
    def __init__(self, agent,
                 accelerator,
                 lm_lr: float = 1e-5,
                 batch_size: int = 4,
                 max_grad_norm: float = 1.0,
                 image_use_str: bool = False):
        """
        Filtered behavior-cloning trainer: maximizes the log-likelihood of the
        recorded actions under the agent's policy.
        """
        super().__init__()
        self.agent = agent
        self.lm_optimizer = torch.optim.Adam(agent.base.parameters(), lr=lm_lr)
        self.batch_size = batch_size
        self.max_grad_norm = max_grad_norm
        self.accelerator = accelerator
        self.image_use_str = image_use_str

    def prepare(self):
        self.agent.base, self.lm_optimizer = self.accelerator.prepare(self.agent.base, self.lm_optimizer)

    def actor_loss(self, observation, action, **kwargs):
        # BC loss: negative log-likelihood of the ground-truth action.
        loss = -self.agent.get_log_prob(observation, action).mean()
        self.accelerator.backward(loss)
        return {"bc.loss": loss.detach().cpu().item()}

    def actor_validate(self, observation, action, **kwargs):
        with torch.no_grad():
            loss = -self.agent.get_log_prob(observation, action).mean(dim=1).mean()
        return {"validate.bc.loss": loss.detach().cpu().item()}
        # NOTE: the return above short-circuits the action-accuracy check below;
        # it is kept for reference but never executed.
        outputs = self.agent.get_action(observation)
        corrects = []
        ill_formated = 0
        wrong_actions = 0
        for output, act in zip(outputs, action):
            try:
                result = output.split("Action: ")[1] == act.split("Action: ")[1]
                if not result:
                    wrong_actions += 1
                    print("======> Prediction")
                    print(output)
                    print("======> Ground Truth")
                    print(act)
                corrects.append(result)
            except IndexError:
                # The output lacked an "Action: " segment entirely.
                print("======> Prediction")
                print(output)
                print("======> Ground Truth")
                print(act)
                ill_formated += 1
                corrects.append(False)
        return {"validate.bc.loss": loss.detach().cpu().item(),
                "validate.bc.action_correct": sum(corrects) / len(corrects),
                "validate.bc.ill_formated": ill_formated / len(corrects),
                "validate.bc.wrong_actions": wrong_actions / len(corrects)}

    def update(self, trajectories, actor_trajectories, iter):
        self.agent.base.train()
        random.seed(iter)
        # Sample up to `actor_trajectories` trajectories and flatten them into steps.
        data = sum(random.sample(trajectories, min(actor_trajectories, len(trajectories))), [])
        dataloader = DataLoader(DummyImageDataset(data, self.image_use_str), batch_size=self.batch_size, shuffle=True, num_workers=8)
        dataloader = self.accelerator.prepare(dataloader)
        info_list = []
        for sample in tqdm(dataloader, disable=not self.accelerator.is_main_process):
            with self.accelerator.accumulate(self.agent.base):
                info_list.append(self.actor_loss(**sample))
                # Gradient clipping with self.max_grad_norm is currently disabled.
                self.lm_optimizer.step()
                self.lm_optimizer.zero_grad()
        info = dict_mean(info_list)
        torch.cuda.empty_cache()
        return info

    def validate(self, trajectories):
        self.agent.base.eval()
        data = sum(trajectories, [])
        dataloader = DataLoader(DummyImageDataset(data, self.image_use_str), batch_size=self.batch_size, shuffle=True, num_workers=8)
        dataloader = self.accelerator.prepare(dataloader)
        info_list = []
        with torch.no_grad():
            for sample in tqdm(dataloader, disable=not self.accelerator.is_main_process):
                info_list.append(self.actor_validate(**sample))
        return dict_mean(info_list)

    def save(self, path):
        # Checkpoint model, optimizer, and RNG state through the accelerator.
        self.accelerator.save_state(path)

    def load(self, path):
        self.accelerator.load_state(path)
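As a usage sketch of the lifecycle above (hypothetical: `agent`, the trajectory lists, and `num_iters` are assumptions supplied by the caller; the accelerator is assumed to be a Hugging Face accelerate Accelerator, which the prepare/backward/save_state calls suggest):

from accelerate import Accelerator  # assumed backend for the accelerator calls above
from pae.algorithms.filteredbc import BCTrainer

accelerator = Accelerator()
trainer = BCTrainer(agent, accelerator, lm_lr=1e-5, batch_size=4)  # `agent` supplied by the caller
trainer.prepare()  # wrap model and optimizer for (distributed) training

for it in range(num_iters):  # num_iters chosen by the caller
    stats = trainer.update(trajectories, actor_trajectories=16, iter=it)
    val_stats = trainer.validate(val_trajectories)  # reports validate.bc.loss
    trainer.save(f"checkpoints/iter_{it}")  # delegates to accelerator.save_state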