feat: clip_back supports cn_clip #358

Open · wants to merge 2 commits into base: main
38 changes: 36 additions & 2 deletions clip_retrieval/clip_back.py
@@ -42,6 +42,19 @@


LOGGER = logging.getLogger(__name__)
# DEBUG:10, INFO:20, WARNING:30(default), ERROR:40, CRITICAL:50
LOGGER.setLevel(logging.INFO)
print(f"The current log level is: {LOGGER.getEffectiveLevel()}")

# NOTE: In Python's logging module, the logger (Logger) decides which log messages
# get processed, while the handler (Handler) decides where they are sent.
# Create a StreamHandler that writes log records to the console
console_handler = logging.StreamHandler()
console_handler.setLevel(logging.INFO)
# Create a Formatter that defines the layout of each log record
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
console_handler.setFormatter(formatter)
# Attach the handler to the logger
LOGGER.addHandler(console_handler)


for coll in list(REGISTRY._collector_to_names.keys()): # pylint: disable=protected-access
@@ -214,9 +227,11 @@ def compute_query(
use_mclip,
aesthetic_score,
aesthetic_weight,
**kwargs,
):
"""compute the query embedding"""
import torch # pylint: disable=import-outside-toplevel
verbose = kwargs.pop("verbose", False)

if text_input is not None and text_input != "":
if use_mclip:
@@ -230,6 +245,11 @@
text_features = clip_resource.model.encode_text(text)
text_features /= text_features.norm(dim=-1, keepdim=True)
query = text_features.cpu().to(torch.float32).detach().numpy()
if verbose:
LOGGER.info(f'text:{text_input}')
LOGGER.info(f'text_token:{text}')
LOGGER.info(f'normalized_text_features:{text_features[0][0:16]}')

elif image_input is not None or image_url_input is not None:
if image_input is not None:
binary_data = base64.b64decode(image_input)
@@ -341,9 +361,10 @@ def post_filter(
return to_remove

def knn_search(
- self, query, modality, num_result_ids, clip_resource, deduplicate, use_safety_model, use_violence_detector
+ self, query, modality, num_result_ids, clip_resource, deduplicate, use_safety_model, use_violence_detector, **kwargs
):
"""compute the knn search"""
verbose = kwargs.pop("verbose", False)

image_index = clip_resource.image_index
text_index = clip_resource.text_index
@@ -377,6 +398,10 @@ def knn_search(
result_distances = distances[0][:nb_results]
result_embeddings = embeddings[0][:nb_results]
result_embeddings = normalized(result_embeddings)
if verbose:
LOGGER.info(f'embeddings:{embeddings}')
LOGGER.info(f'normalized_result_embeddings:{result_embeddings[0][0:16]}')

local_indices_to_remove = self.post_filter(
clip_resource.safety_model,
result_embeddings,
@@ -432,6 +457,7 @@ def query(
use_violence_detector=False,
aesthetic_score=None,
aesthetic_weight=None,
**kwargs,
):
"""implement the querying functionality of the knn service: from text and image to nearest neighbors"""

@@ -451,6 +477,7 @@
use_mclip=use_mclip,
aesthetic_score=aesthetic_score,
aesthetic_weight=aesthetic_weight,
**kwargs,
)
distances, indices = self.knn_search(
query,
@@ -460,6 +487,7 @@
deduplicate=deduplicate,
use_safety_model=use_safety_model,
use_violence_detector=use_violence_detector,
**kwargs,
)
if len(distances) == 0:
return []
@@ -489,6 +517,7 @@ def post(self):
aesthetic_score = int(aesthetic_score) if aesthetic_score != "" else None
aesthetic_weight = json_data.get("aesthetic_weight", "")
aesthetic_weight = float(aesthetic_weight) if aesthetic_weight != "" else None
verbose = json_data.get("verbose", False)
return self.query(
text_input,
image_input,
@@ -504,6 +533,7 @@ def post(self):
use_violence_detector,
aesthetic_score,
aesthetic_weight,
verbose=verbose,
)


@@ -792,6 +822,7 @@ class ClipOptions:

indice_folder: str
clip_model: str
clip_cache_path: str
enable_hdf5: bool
enable_faiss_memory_mapping: bool
columns_to_return: List[str]
@@ -808,6 +839,7 @@ def dict_to_clip_options(d, clip_options):
return ClipOptions(
indice_folder=d["indice_folder"] if "indice_folder" in d else clip_options.indice_folder,
clip_model=d["clip_model"] if "clip_model" in d else clip_options.clip_model,
clip_cache_path=d["clip_cache_path"] if "clip_cache_path" in d else clip_options.clip_cache_path,
enable_hdf5=d["enable_hdf5"] if "enable_hdf5" in d else clip_options.enable_hdf5,
enable_faiss_memory_mapping=d["enable_faiss_memory_mapping"]
if "enable_faiss_memory_mapping" in d
@@ -865,7 +897,7 @@ def load_clip_index(clip_options):
from all_clip import load_clip # pylint: disable=import-outside-toplevel

device = "cuda" if torch.cuda.is_available() else "cpu"
- model, preprocess, tokenizer = load_clip(clip_options.clip_model, use_jit=clip_options.use_jit, device=device)
+ model, preprocess, tokenizer = load_clip(clip_options.clip_model, use_jit=clip_options.use_jit, device=device, clip_cache_path=clip_options.clip_cache_path)

if clip_options.enable_mclip_option:
model_txt_mclip = load_mclip(clip_options.clip_model)
@@ -961,6 +993,7 @@ def clip_back(
url_column="url",
enable_mclip_option=True,
clip_model="ViT-B/32",
clip_cache_path=None,
use_jit=True,
use_arrow=False,
provide_safety_model=False,
@@ -976,6 +1009,7 @@
clip_options=ClipOptions(
indice_folder="",
clip_model=clip_model,
clip_cache_path=clip_cache_path,
enable_hdf5=enable_hdf5,
enable_faiss_memory_mapping=enable_faiss_memory_mapping,
columns_to_return=columns_to_return,
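Taken together, the clip_back.py changes add two things: a clip_cache_path option forwarded to all_clip's load_clip (presumably so cn_clip weights can be loaded from a local cache), and a verbose flag read from the request body that logs the tokenized text, the normalized text features, and the raw and normalized result embeddings. Below is a minimal sketch of exercising the flag over HTTP; the host, port, and index name are illustrative assumptions, not values from this PR:

import json

import requests

payload = {
    "text": "一只橘猫",  # with a cn_clip model, Chinese queries are encoded natively
    "modality": "image",
    "indice_name": "my_index",  # hypothetical index name
    "num_images": 5,
    "num_result_ids": 5,
    "verbose": True,  # enables the new server-side LOGGER.info diagnostics
}
response = requests.post("http://localhost:1234/knn-service", data=json.dumps(payload), timeout=3600)
print(response.json())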
9 changes: 7 additions & 2 deletions clip_retrieval/clip_client.py
@@ -52,12 +52,15 @@ def __init__(
self.deduplicate = deduplicate
self.use_safety_model = use_safety_model
self.use_violence_detector = use_violence_detector

def query(
self,
text: Optional[str] = None,
image: Optional[str] = None,
embedding_input: Optional[list] = None,
verbose: bool = False,
) -> List[Dict]:
"""
Given text or image/s, search for other captions/images that are semantically similar.
@@ -81,10 +84,10 @@
if text and image:
raise ValueError("Only one of text or image can be provided.")
if text:
- return self.__search_knn_api__(text=text)
+ return self.__search_knn_api__(text=text, verbose=verbose)
elif image:
if image.startswith("http"):
- return self.__search_knn_api__(image_url=image)
+ return self.__search_knn_api__(image_url=image, verbose=verbose)
else:
assert Path(image).exists(), f"{image} does not exist."
return self.__search_knn_api__(image=image, verbose=verbose)
@@ -99,6 +102,7 @@ def __search_knn_api__(
image: Optional[str] = None,
image_url: Optional[str] = None,
embedding_input: Optional[list] = None,
verbose: bool = False,
) -> List:
"""
This function is used to send the request to the knn service.
@@ -147,6 +151,7 @@ def __search_knn_api__(
"num_images": self.num_images,
# num_results_ids is hardcoded to the num_images parameter.
"num_result_ids": self.num_images,
"verbose": verbose,
}
),
timeout=3600,
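On the client side, verbose defaults to False and is simply added to the request payload, so existing callers keep their behavior. A minimal usage sketch; the URL and index name are placeholders:

from clip_retrieval.clip_client import ClipClient

client = ClipClient(url="http://localhost:1234/knn-service", indice_name="my_index")
results = client.query(text="an orange cat", verbose=True)  # asks the backend to log query diagnostics
print(results[0])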
6 changes: 3 additions & 3 deletions clip_retrieval/clip_inference/reader.py
@@ -15,7 +15,7 @@ def folder_to_keys(folder, enable_text=True, enable_image=True, enable_metadata=
image_files = None
if enable_text:
text_files = [*path.glob("**/*.txt")]
- text_files = {text_file.relative_to(path).as_posix(): text_file for text_file in text_files}
+ text_files = {text_file.relative_to(path).with_suffix('').as_posix(): text_file for text_file in text_files}
if enable_image:
image_files = [
*path.glob("**/*.png"),
@@ -29,10 +29,10 @@
*path.glob("**/*.BMP"),
*path.glob("**/*.WEBP"),
]
- image_files = {image_file.relative_to(path).as_posix(): image_file for image_file in image_files}
+ image_files = {image_file.relative_to(path).with_suffix('').as_posix(): image_file for image_file in image_files}
if enable_metadata:
metadata_files = [*path.glob("**/*.json")]
- metadata_files = {metadata_file.relative_to(path).as_posix(): metadata_file for metadata_file in metadata_files}
+ metadata_files = {metadata_file.relative_to(path).with_suffix('').as_posix(): metadata_file for metadata_file in metadata_files}

keys = None

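The reader.py change strips file extensions when building the key maps, so a caption file and its image now map to the same key. Previously "imgs/cat.txt" and "imgs/cat.jpg" produced distinct keys, and the intersection used to pair texts with images was empty whenever both were enabled. A small illustration of the new keying (the paths are made up):

from pathlib import Path

path = Path("dataset")
text_file = path / "imgs" / "cat.txt"
old_key = text_file.relative_to(path).as_posix()                  # "imgs/cat.txt"
new_key = text_file.relative_to(path).with_suffix('').as_posix()  # "imgs/cat", same key as imgs/cat.jpg
print(old_key, new_key)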