-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathextract_ion_images.py
82 lines (66 loc) · 2.2 KB
/
extract_ion_images.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import argparse
from pathlib import Path
import pandas as pd
from imzml2np import extract_peaks, peaks_df_to_images
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
def build_parser():
    """Return the argument parser for the ion image extraction script."""
    parser = argparse.ArgumentParser(
        description="""Script to extract ion images from imzml.
        The script takes as input path to imzml and list of ions to annotate,
        then saves corresponding ion images as a numpy array.
        In addition, calculate and plot sum of given ions as an example in the end.
        """
    )
    parser.add_argument(
        "imzml", type=str, help="imzml file to process",
    )
    parser.add_argument(
        "metadata", type=str, help="list of ions to analyze in metaspace output format",
    )
    # The array is written with np.save, i.e. in .npy format (the help text
    # previously advertised .npz, which np.save does not produce).
    parser.add_argument(
        "output_path", type=str, help="path to output .npy file",
    )
    parser.add_argument(
        "--tol", type=float, default=5, help="Tolerance in ppm at the base mz",
    )
    parser.add_argument(
        "--base_mz", type=float, default=200, help="Base mz for mass tolerance",
    )
    return parser


def save_images(output_path, images):
    """Write ``images`` to exactly ``output_path`` in NumPy ``.npy`` format.

    ``np.save`` appends ``.npy`` to a string filename that lacks that suffix,
    so a path such as ``out.npz`` would end up as ``out.npz.npy`` and a later
    ``np.load(output_path)`` would not find it. Passing an open file object
    leaves the file name unchanged.
    """
    with open(output_path, "wb") as handle:
        np.save(handle, images, allow_pickle=True)


def main():
    """Extract ion images, save them, reload them and plot their sum.

    Reads the ion list from the metadata CSV, extracts the matching peaks
    from the imzML file, stacks one image per ion into a single array, saves
    it to ``output_path`` and writes a heatmap of the summed images to
    ``sum_image.png`` in the current working directory.
    """
    # Parse parameters
    args = build_parser().parse_args()

    # Open metadata
    metadata = pd.read_csv(args.metadata)
    print(metadata)
    coords_df, peaks = extract_peaks(
        args.imzml,
        metadata,
        tol_ppm=args.tol,
        tol_mode="orbitrap",
        base_mz=args.base_mz,
    )

    # Extract images: element 1 of the value returned by peaks_df_to_images is
    # used as the image and transposed (presumably a 2-D array -- confirm
    # against imzml2np).
    images = np.array(
        [peaks_df_to_images(coords_df, peak["peaks_df"])[1].T for peak in peaks]
    )
    print(f"Extracted {images.shape[0]} ion images of shape {images.shape[1:]}")

    # Save as a numpy object
    save_images(args.output_path, images)

    # Open numpy object (round-trip check that the saved file is readable)
    images = np.load(args.output_path)
    print(f"Read numpy array of shape {images.shape}")

    # Calculate and plot sum or do whatever
    sum_image = np.sum(images, axis=0)
    sns.heatmap(
        sum_image,
        cmap="magma",
        cbar=True,
        linewidths=0,
        xticklabels=False,
        yticklabels=False,
        square=True,
    )
    plt.tight_layout()
    plt.savefig("sum_image.png", dpi=300, bbox_inches="tight")


if __name__ == "__main__":
    main()