Add BRATS expert model to support MRI image (#55)
This PR
- adds BRATS for brain MRI segmentation.
- fixes markdown link linting (#53)

---------

Signed-off-by: Mingxin Zheng <[email protected]>
mingxin-zheng authored Nov 23, 2024
1 parent 948041d commit 7be688a
Showing 6 changed files with 273 additions and 70 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/check-links.yml
@@ -21,7 +21,7 @@ jobs:
node-version: '18'

- name: Install markdown-link-check
-        run: npm install -g markdown-link-check
+        run: npm install -g markdown-link-check@3.12.2

- name: Check for broken links
run: find . -name "*.md" | xargs -I {} markdown-link-check {} --config .github/markdown-link-check.json
1 change: 1 addition & 0 deletions Makefile
@@ -21,4 +21,5 @@ demo_m3:
-O $(HOME)/.torchxrayvision/models_data/pc-nih-rsna-siim-vin-resnet50-test512-e400-state.pt; \
mkdir -p $(HOME)/.cache/torch/hub/bundle \
&& python -m monai.bundle download vista3d --version 0.5.4 --bundle_dir $(HOME)/.cache/torch/hub/bundle \
&& python -m monai.bundle download brats_mri_segmentation --version 0.5.2 --bundle_dir $(HOME)/.cache/torch/hub/bundle \
&& unzip $(HOME)/.cache/torch/hub/bundle/vista3d_v0.5.4.zip -d $(HOME)/.cache/torch/hub/bundle/vista3d_v0.5.4
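For reference, the same bundles can be pulled from Python rather than the Makefile. This is a minimal sketch, assuming the bundle directory should mirror the $(HOME)/.cache/torch/hub/bundle path used above; it relies only on monai.bundle.download with the version pins taken from the diff.

```python
# Hedged sketch: programmatic equivalent of the Makefile bundle downloads above.
from pathlib import Path

from monai.bundle import download

# Assumed to mirror the Makefile's $(HOME)/.cache/torch/hub/bundle location.
bundle_dir = Path.home() / ".cache" / "torch" / "hub" / "bundle"
download(name="vista3d", version="0.5.4", bundle_dir=str(bundle_dir))
download(name="brats_mri_segmentation", version="0.5.2", bundle_dir=str(bundle_dir))
```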
161 changes: 161 additions & 0 deletions m3/demo/experts/expert_monai_brats.py
@@ -0,0 +1,161 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import re
import tempfile
from pathlib import Path
from shutil import move
from uuid import uuid4

import requests
from experts.base_expert import BaseExpert
from experts.utils import get_monai_transforms, get_slice_filenames
from monai.bundle import create_workflow


class ExpertBrats(BaseExpert):
    """Expert model for BRATS."""

    def __init__(self) -> None:
        """Initialize the BRATS expert model."""
        self.model_name = "BRATS"
        self.bundle_root = os.path.expanduser("~/.cache/torch/hub/bundle/brats_mri_segmentation")

    def segmentation_to_string(
        self,
        output_dir: Path,
        img_file: str,
        seg_file: str,
        slice_index: int,
        image_filename: str,
        label_filename: str,
        modality: str = "MRI",
        axis: int = 2,
        output_prefix="The results are <segmentation>. The colors in this image describe\n",
    ):
        """Convert the segmentation to a string."""
        output_dir = Path(output_dir)

        transforms = get_monai_transforms(
            ["image", "label"],
            output_dir,
            modality=modality,
            slice_index=slice_index,
            axis=axis,
            image_filename=image_filename,
            label_filename=label_filename,
        )
        data = transforms({"image": img_file, "label": seg_file})
        # BRATS label convention: 1 = necrotic tumor core, 2 = peritumoral edema, 4 = enhancing tumor
        ncr = data["colormap"].get(1, None)
        ed = data["colormap"].get(2, None)
        et = data["colormap"].get(4, None)
        output = output_prefix
        if ncr is not None and et is not None:
            output += f"{ncr} and {et}: tumor core, "
        if et is not None:
            output += f"only {et}: enhancing tumor, "
        if ncr is not None or et is not None or ed is not None:
            output += "all colors: whole tumor\n"
        return output

    def mentioned_by(self, input: str):
        """
        Check if the BRATS model is mentioned in the input.
        Args:
            input (str): Text from the LLM, e.g. "Let me trigger <BRATS>."
        Returns:
            bool: True if the BRATS model is mentioned, False otherwise.
        """
        matches = re.findall(r"<(.*?)>", str(input))
        if len(matches) != 1:
            return False
        return self.model_name in str(matches[0])

    def download_file(self, url: str, img_file: str):
        """
        Download the file from the URL.
        Args:
            url (str): The URL.
            img_file (str): The file path.
        """
        parent_dir = os.path.dirname(img_file)
        os.makedirs(parent_dir, exist_ok=True)
        with open(img_file, "wb") as f:
            response = requests.get(url)
            f.write(response.content)

    def run(
        self,
        img_file: list[str] | None = None,
        image_url: list[str] | None = None,
        input: str = "",
        output_dir: str = "",
        slice_index: int = 0,
        prompt: str = "",
        **kwargs,
    ):
        """
        Run the BRATS model.
        Args:
            image_url (list[str]): The list of image URLs.
            input (str): The input text.
            output_dir (str): The output directory.
            img_file (list[str]): The list of image file paths. If not provided, download from the URLs.
            slice_index (int): The slice index.
            prompt (str): The prompt text from the original request.
            **kwargs: Additional keyword arguments.
        """
        if not img_file:
            # Download the files from the URLs
            img_file = []
            for url in image_url:
                file_path = os.path.join(output_dir, os.path.basename(url))
                self.download_file(url, file_path)
                img_file.append(file_path)

        with tempfile.TemporaryDirectory() as temp_dir:
            workflow = create_workflow(
                workflow_type="infer",
                bundle_root=self.bundle_root,
                config_file=os.path.join(self.bundle_root, "configs/inference.json"),
                logging_file=os.path.join(self.bundle_root, "configs/logging.conf"),
                meta_file=os.path.join(self.bundle_root, "configs/metadata.json"),
                test_datalist=[{"image": img_file}],
                output_dtype="uint8",
                separate_folder=False,
                output_ext=".nii.gz",
                output_dir=temp_dir,
            )
            workflow.evaluator.run()
            output_file = os.path.join(temp_dir, os.listdir(temp_dir)[0])
            seg_file = os.path.join(output_dir, "segmentation.nii.gz")
            move(output_file, seg_file)

        seg_image = f"seg_{uuid4()}.jpg"
        text_output = self.segmentation_to_string(
            output_dir,
            img_file[0],
            seg_file,
            slice_index,
            get_slice_filenames(img_file[0], slice_index),
            seg_image,
            modality="MRI",
            axis=2,
        )

        if "segmented" in input:
            instruction = ""  # no need to ask for instruction
        else:
            instruction = "Use this result to respond to this prompt:\n" + prompt
        return text_output, os.path.join(output_dir, seg_image), instruction
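For context, a minimal usage sketch of the new expert (not part of the diff): the URLs, output directory, slice index, and prompt are placeholders, and the four-volume input assumes the usual BRATS modalities (T1, T1ce, T2, FLAIR) for one study.

```python
# Hedged sketch of how the demo might drive ExpertBrats; all inputs are placeholders.
from experts.expert_monai_brats import ExpertBrats

expert = ExpertBrats()
llm_output = "This is a brain MRI, let me trigger <BRATS>."
if expert.mentioned_by(llm_output):
    text_output, seg_image_path, instruction = expert.run(
        image_url=[
            "https://example.com/brats_t1.nii.gz",    # placeholder URLs
            "https://example.com/brats_t1ce.nii.gz",
            "https://example.com/brats_t2.nii.gz",
            "https://example.com/brats_flair.nii.gz",
        ],
        input=llm_output,
        output_dir="./output",
        slice_index=77,
        prompt="Describe the tumor sub-regions in this study.",
    )
    print(text_output)     # "The results are <segmentation>. The colors in this image describe ..."
    print(seg_image_path)  # ./output/seg_<uuid>.jpg
```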
62 changes: 62 additions & 0 deletions m3/demo/experts/utils.py
@@ -16,14 +16,17 @@
import re
from io import BytesIO
from pathlib import Path
from shutil import copyfile, rmtree

import nibabel as nib
import numpy as np
import requests
import skimage
from monai.transforms import Compose, LoadImageD, MapTransform, OrientationD, ScaleIntensityD, ScaleIntensityRangeD
from PIL import Image
from PIL import Image as PILImage
from PIL.Image import Image
from tqdm import tqdm

logger = logging.getLogger("gradio_m3")

@@ -154,6 +157,8 @@ def _get_modality_url(image_url_or_path: str | None):
    If the URL or file path contains ".nii.gz" and contains "mri_", then it is MRI, else it is CT.
    If it contains "cxr_" then it is CXR, otherwise it is Unknown.
    """
    if isinstance(image_url_or_path, list) and len(image_url_or_path) > 0:
        image_url_or_path = image_url_or_path[0]
    if not isinstance(image_url_or_path, str):
        return "Unknown"
    if image_url_or_path.startswith("data:image"):
@@ -339,3 +344,60 @@ def resize_data_url(data_url, max_size):
    img_base64 = base64.b64encode(img_byte).decode()
    # Convert the base64 bytes to string and format the data URL
    return f"data:image/jpeg;base64,{img_base64}"


class ImageCache:
    """A simple image cache to store images and data URLs."""

    def __init__(self, cache_dir: Path):
        """Initialize the image cache."""
        cache_dir = Path(cache_dir)
        if not cache_dir.exists():
            cache_dir.mkdir(parents=True)
        self.cache_dir = cache_dir
        self.cache_images = {}

    def cache(self, image_urls_or_paths):
        """Cache the images from the URLs or paths."""
        logger.debug(f"Caching the image to {self.cache_dir}")
        for _, items in image_urls_or_paths.items():
            items = items if isinstance(items, list) else [items]
            for item in items:
                if item.startswith("http"):
                    self.cache_images[item] = save_image_url_to_file(item, self.cache_dir)
                elif os.path.exists(item):
                    # copy the file to the cache directory
                    file_name = os.path.basename(item)
                    self.cache_images[item] = os.path.join(self.cache_dir, file_name)
                    if not os.path.isfile(self.cache_images[item]):
                        copyfile(item, self.cache_images[item])

                if self.cache_images[item].endswith(".nii.gz"):
                    data = nib.load(self.cache_images[item]).get_fdata()
                    for slice_index in tqdm(range(data.shape[2])):
                        image_filename = get_slice_filenames(self.cache_images[item], slice_index)
                        if not os.path.exists(os.path.join(self.cache_dir, image_filename)):
                            compose = get_monai_transforms(
                                ["image"],
                                self.cache_dir,
                                modality=get_modality(item),
                                slice_index=slice_index,
                                image_filename=image_filename,
                            )
                            compose({"image": self.cache_images[item]})

    def cleanup(self):
        """Clean up the cache directory."""
        logger.debug("Cleaning up the cache")
        rmtree(self.cache_dir)

    def dir(self):
        """Return the cache directory."""
        return str(self.cache_dir)

    def get(self, key: str | list, default=None, list_return=False):
        """Get the image or data URL from the cache."""
        if isinstance(key, list):
            items = [self.cache_images.get(k) for k in key]
            return items if list_return else items[0]
        return self.cache_images.get(key, default)
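Below is a short, hedged sketch of how the demo could exercise ImageCache; the cache directory and the NIfTI path are placeholders, and the local path must exist (or be an http(s) URL) for caching to succeed.

```python
# Hedged sketch: caching a local NIfTI volume and retrieving it again.
from pathlib import Path

from experts.utils import ImageCache

cache = ImageCache(Path("/tmp/m3_image_cache"))      # placeholder cache location
cache.cache({"MRI": ["/data/brats_flair.nii.gz"]})   # pre-renders a JPEG per axial slice
volume_path = cache.get("/data/brats_flair.nii.gz")  # cached copy of the volume
paths = cache.get(["/data/brats_flair.nii.gz"], list_return=True)
cache.cleanup()                                      # removes the cache directory
```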