While testing the official sample code, I ran into an issue: when I used the script below to encode my dataset, a large number of all-zero tensors appeared in the final saved array, roughly 5 million out of 20 million rows.
import argparse
import logging
import os

import numpy as np
import datasets
from torch.utils.data import DataLoader
from tqdm import tqdm

from sentence_transformers import LoggingHandler, SentenceTransformer

logging.basicConfig(
    format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO, handlers=[LoggingHandler()]
)


def parse_dataset(dataset_str):
    if dataset_str.endswith('.tsv'):
        # Try to read a TSV file
        dataset = datasets.load_dataset("csv", data_files=dataset_str, delimiter="\t", num_proc=32)['train']
    else:
        dataset = datasets.load_from_disk(dataset_str)
    return dataset


def parse_args():
    parser = argparse.ArgumentParser(description="Process datasets with padding and multiprocessing.")
    parser.add_argument("--model", default="stella", help="Model name or path.")
    parser.add_argument("--dataset", required=True, help="Path to the dataset.")
    parser.add_argument("--batch_size", type=int, default=1024, help="Batch size per worker.")
    parser.add_argument("--output_dir", default="./embeddings", help="Directory to save embeddings.")
    return parser.parse_args()


if __name__ == "__main__":
    # Set params
    args = parse_args()
    # Load the dataset
    dataset = parse_dataset(args.dataset)
    # Load the model
    model = SentenceTransformer(args.model, trust_remote_code=True)
    # Start the multi-process pool
    pool = model.start_multi_process_pool()
    # Set up the DataLoader
    dataloader = DataLoader(
        dataset=dataset,
        batch_size=8192,
        num_workers=8,      # adjust to the number of CPU cores
        pin_memory=True,    # enable pin_memory
        prefetch_factor=4,  # increase the number of prefetched batches
    )
    # Ensure the output directory exists
    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)
    # Collect the embeddings of all batches
    all_embeddings = []
    for i, batch in enumerate(tqdm(dataloader)):
        sentences = batch["text"]
        batch_emb = model.encode_multi_process(sentences, pool, batch_size=2048)
        print(f"Embeddings computed for batch {i+1}. Shape: {batch_emb.shape}")
        all_embeddings.append(batch_emb)
    all_embeddings = np.vstack(all_embeddings)
    print(f"Total embeddings shape: {all_embeddings.shape}")
    output_file = os.path.join(args.output_dir, "all_embeddings.npy")
    np.save(output_file, all_embeddings)
    print(f"All embeddings saved at {output_file}")
    model.stop_multi_process_pool(pool)
    print("All embeddings have been processed and saved.")
I detected this issue by loading the already saved tensors and comparing them against freshly computed embeddings: the closeness check (are_close) is False at specific rows (e.g. 1025 or 10000011).
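For reference, a minimal sketch of the comparison I mean (assuming dataset and model are loaded as in the script above and the embeddings were saved to the default output path):

import numpy as np

# Compare the saved vectors against single-process re-encodes of the same rows.
all_embeddings = np.load("./embeddings/all_embeddings.npy")
suspect_rows = [1025, 10000011]
texts = [dataset[int(i)]["text"] for i in suspect_rows]
reference = model.encode(texts)

for row, ref in zip(suspect_rows, reference):
    are_close = np.allclose(all_embeddings[row], ref, atol=1e-5)
    print(f"row {row}: are_close={are_close}")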
The following code finds roughly 5 million all-zero rows:
zero_mask = np.all(all_embeddings == 0, axis=1)
zero_indices = np.where(zero_mask)[0]
print(f"Found {len(zero_indices)} zero embeddings at indices: {zero_indices}")
What could be the possible reasons for this issue? Thank you for your assistance.