Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
  • Loading branch information
xin-huang committed Nov 14, 2024
1 parent dc3797d commit bd8d5d2
Show file tree
Hide file tree
Showing 7 changed files with 92 additions and 210 deletions.
5 changes: 4 additions & 1 deletion examples/data/example.vcf
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,7 @@
21 999 . A T 100 PASS AA=A GT 0|0 1|1 0|0 0|0 0|0 0|0 0|0 0|0 0|0 0|0 0|0
21 1111 . A T 100 PASS AA=A GT 0|1 1|1 1|1 1|1 1|1 0|0 0|0 0|0 0|0 0|0 1|1
21 2222 . A T 100 PASS AA=A GT 0|0 0|0 0|0 0|0 0|0 0|0 0|0 0|0 0|1 1|1 1|1
21 3333 . A T 100 PASS AA=A GT 1|1 1|1 0|0 0|0 0|1 0|0 0|0 0|0 0|0 0|0 1|0
21 3333 . A T 100 PASS AA=A GT .|. 1|1 0|0 0|0 0|1 0|0 0|0 0|0 0|0 0|0 1|0
21 4444 . A T 100 PASS AA=A GT .|. 1|1 0|0 0|0 0|1 .|. 0|0 .|. 0|0 0|0 1|0
21 5555 . A T 100 PASS AA=A GT .|. 1|1 0|0 0|0 0|1 0|0 0|0 0|0 0|0 0|0 .|.
21 6666 . A T 100 PASS AA=A GT 1|1 1|1 0|0 0|0 0|1 0|0 0|0 0|0 0|0 0|0 1|0
129 changes: 28 additions & 101 deletions sai/stats/features.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def calc_u(
) -> int:
"""
Calculates the count of genetic loci that meet specified allele frequency conditions
across reference, target, and multiple source genotypes.
across reference, target, and multiple source genotypes, with adjustments based on src_freq consistency.
Parameters
----------
Expand All @@ -83,11 +83,6 @@ def calc_u(
-------
int
The count of loci that meet all specified frequency conditions.
Raises
------
ValueError
If `w`, `x`, or any value in `y_list` is outside the range [0, 1], or if `y_list` length does not match `src_gts_list`.
"""
# Validate input parameters
if not (0 <= w <= 1 and 0 <= x <= 1):
Expand All @@ -102,12 +97,20 @@ def calc_u(
tgt_freq = calc_freq(tgt_gts, ploidy)
src_freq_list = [calc_freq(src_gts, ploidy) for src_gts in src_gts_list]

# Set initial condition for reference and target populations
condition = (ref_freq < w) & (tgt_freq > x)
# Check if all src_freq values match y or 1 - y for each locus
all_match_y = np.all(
[src_freq == y for src_freq, y in zip(src_freq_list, y_list)], axis=0
)
all_match_1_minus_y = np.all(
[src_freq == 1 - y for src_freq, y in zip(src_freq_list, y_list)], axis=0
)

# Directly modify ref_freq and tgt_freq where src_freq matches 1 - y
ref_freq[all_match_1_minus_y] = 1 - ref_freq[all_match_1_minus_y]
tgt_freq[all_match_1_minus_y] = 1 - tgt_freq[all_match_1_minus_y]

# Add conditions for each source population
for src_freq, y in zip(src_freq_list, y_list):
condition &= src_freq == y
# Apply final condition: loci that match the adjusted conditions in ref and tgt
condition = (all_match_y | all_match_1_minus_y) & (ref_freq < w) & (tgt_freq > x)

# Count loci meeting all specified conditions
count = np.sum(condition)
Expand All @@ -125,7 +128,7 @@ def calc_q(
) -> float:
"""
Calculates a specified quantile of derived allele frequencies in `tgt_gts` for loci that meet specific conditions
across reference and multiple source genotypes.
across reference and multiple source genotypes, with adjustments based on src_freq consistency.
Parameters
----------
Expand Down Expand Up @@ -153,11 +156,6 @@ def calc_q(
float
The specified quantile of the derived allele frequencies in `tgt_gts` for loci meeting the specified conditions,
or NaN if no loci meet the criteria.
Raises
------
ValueError
If `w`, `quantile`, or any value in `y_list` is outside the range [0, 1], or if `y_list` length does not match `src_gts_list`.
"""
# Validate input parameters
if not (0 <= w <= 1 and 0 <= quantile <= 1):
Expand All @@ -172,12 +170,20 @@ def calc_q(
tgt_freq = calc_freq(tgt_gts, ploidy)
src_freq_list = [calc_freq(src_gts, ploidy) for src_gts in src_gts_list]

# Set initial condition for reference population
condition = ref_freq < w
# Check if all src_freq values match y or 1 - y for each locus
all_match_y = np.all(
[src_freq == y for src_freq, y in zip(src_freq_list, y_list)], axis=0
)
all_match_1_minus_y = np.all(
[src_freq == 1 - y for src_freq, y in zip(src_freq_list, y_list)], axis=0
)

# Add conditions for each source population
for src_freq, y in zip(src_freq_list, y_list):
condition &= src_freq == y
# Directly modify ref_freq and tgt_freq where src_freq matches 1 - y
ref_freq[all_match_1_minus_y] = 1 - ref_freq[all_match_1_minus_y]
tgt_freq[all_match_1_minus_y] = 1 - tgt_freq[all_match_1_minus_y]

# Apply the final condition based on adjusted ref_freq and tgt_freq
condition = (all_match_y | all_match_1_minus_y) & (ref_freq < w)

# Filter `tgt_gts` frequencies based on the combined condition
filtered_tgt_freq = tgt_freq[condition]
Expand All @@ -188,82 +194,3 @@ def calc_q(

# Calculate and return the specified quantile of the filtered `tgt_gts` frequencies
return np.quantile(filtered_tgt_freq, quantile)


def calc_seq_div(gts1, gts2):
    """
    Counts, for every pair of individuals drawn from two populations, the number of loci
    at which their genotype values differ (a Hamming distance). Values are compared for
    inequality only, so multiallelic data (e.g., values other than 0 and 1) is supported.

    Parameters
    ----------
    gts1 : np.ndarray
        A 2D array (or array-like) where each row represents a locus and each column
        represents an individual in the first population.
    gts2 : np.ndarray
        A 2D array (or array-like) where each row represents a locus and each column
        represents an individual in the second population. Expected to cover the same
        loci (rows) as `gts1`.

    Returns
    -------
    np.ndarray
        A 2D integer array of shape (individuals_gts1, individuals_gts2). Entry (i, j) is
        the number of loci at which individual i of `gts1` differs from individual j of
        `gts2`. Note that this is a matrix of per-pair counts, not a single averaged value.
    """
    # Accept nested lists as well as arrays; a no-op for inputs that are already ndarrays.
    gts1 = np.asarray(gts1)
    gts2 = np.asarray(gts2)

    # Insert singleton axes so the comparison broadcasts over every pair of individuals:
    # (loci, individuals_gts1, 1) against (loci, 1, individuals_gts2).
    div_matrix = gts1[:, :, np.newaxis] != gts2[:, np.newaxis, :]

    # Sum the per-locus mismatches for each pair of individuals.
    pairwise_divergence = np.sum(div_matrix, axis=0)

    return pairwise_divergence


def calc_rd(ref_gts, tgt_gts, src_gts):
    """
    Calculates the ratio of the mean pairwise sequence divergence between the source and
    admixed populations to the mean pairwise sequence divergence between the source and
    non-admixed populations.

    Note that this is a ratio of means (each mean taken over all pairs of individuals),
    not the mean of per-pair ratios.

    Parameters
    ----------
    ref_gts : np.ndarray
        A 2D numpy array where each row represents a locus and each column represents an individual in the non-admixed population.
    tgt_gts : np.ndarray
        A 2D numpy array where each row represents a locus and each column represents an individual in the admixed population.
    src_gts : np.ndarray
        A 2D numpy array where each row represents a locus and each column represents an individual in the source population.

    Returns
    -------
    float
        The divergence ratio, or NaN if the mean divergence between the source and
        non-admixed populations is zero.
    """
    # Pairwise divergence between the source and the non-admixed population
    divergence_src_ref = calc_seq_div(src_gts, ref_gts)

    # Pairwise divergence between the source and the admixed population
    divergence_src_tgt = calc_seq_div(src_gts, tgt_gts)

    # Compute the denominator once; a zero denominator yields NaN instead of a division error.
    mean_divergence_src_ref = np.mean(divergence_src_ref)
    if mean_divergence_src_ref == 0:
        return np.nan

    return np.mean(divergence_src_tgt) / mean_divergence_src_ref
9 changes: 7 additions & 2 deletions sai/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,8 +143,13 @@ def read_geno_data(

# Remove missing data if specified
if filter_missing:
missing_index = chrom_data.GT.count_missing(axis=1) == len(sample_indices)
chrom_data = filter_geno_data(chrom_data, ~missing_index)
non_missing_index = chrom_data.GT.count_missing(axis=1) == 0
num_missing = len(non_missing_index) - np.sum(non_missing_index)
if num_missing != 0:
print(
f"Found {num_missing} variants with missing genotypes, removing them ..."
)
chrom_data = filter_geno_data(chrom_data, non_missing_index)

# Check and incorporate ancestral alleles if the file is provided
if anc_allele:
Expand Down
5 changes: 4 additions & 1 deletion tests/data/example.vcf
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,7 @@
21 999 . A T 100 PASS AA=A GT 0|0 1|1 0|0 0|0 0|0 0|0 0|0 0|0 0|0 0|0 0|0
21 1111 . A T 100 PASS AA=A GT 0|1 1|1 1|1 1|1 1|1 0|0 0|0 0|0 0|0 0|0 1|1
21 2222 . A T 100 PASS AA=A GT 0|0 0|0 0|0 0|0 0|0 0|0 0|0 0|0 0|1 1|1 1|1
21 3333 . A T 100 PASS AA=A GT 1|1 1|1 0|0 0|0 0|1 0|0 0|0 0|0 0|0 0|0 1|0
21 3333 . A T 100 PASS AA=A GT .|. 1|1 0|0 0|0 0|1 0|0 0|0 0|0 0|0 0|0 1|0
21 4444 . A T 100 PASS AA=A GT .|. 1|1 0|0 0|0 0|1 .|. 0|0 .|. 0|0 0|0 1|0
21 5555 . A T 100 PASS AA=A GT .|. 1|1 0|0 0|0 0|1 0|0 0|0 0|0 0|0 0|0 .|.
21 6666 . A T 100 PASS AA=A GT 1|1 1|1 0|0 0|0 0|1 0|0 0|0 0|0 0|0 0|0 1|0
90 changes: 2 additions & 88 deletions tests/stats/test_features.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,6 @@
from sai.stats.features import calc_u
from sai.stats.features import calc_q
from sai.stats.features import calc_freq
from sai.stats.features import calc_seq_div
from sai.stats.features import calc_rd


def test_calc_u_basic():
Expand Down Expand Up @@ -93,7 +91,7 @@ def test_calc_q_basic():

def test_calc_q_no_match():
# Test data with no matching loci
ref_gts = np.array([[0, 0, 1], [1, 1, 1]])
ref_gts = np.array([[0, 0, 1], [0, 0, 0]])
tgt_gts = np.array([[0, 1, 1], [1, 1, 1]])
src_gts = np.array([[1, 1, 1], [1, 1, 1]])
w, y, quantile = (
Expand Down Expand Up @@ -137,7 +135,7 @@ def test_calc_q_edge_case():
w, y, quantile = 0.95, 1.0, 0.95

# Expected output
expected_result = 1.0 # Only one matching site in tgt_gts
expected_result = 0.9666667

# Run test
result = calc_q(ref_gts, tgt_gts, [src_gts], w, [y], quantile)
Expand Down Expand Up @@ -244,87 +242,3 @@ def test_unphased_tetraploid_data():
decimal=6,
err_msg="Unphased tetraploid data test failed.",
)


def test_calc_seq_div():
    """Compare calc_seq_div with hand-computed per-pair difference counts."""
    # Test case 1: mirrored genotypes, so matching columns differ at both loci
    first_pop = np.array([[0, 1], [1, 0]])
    second_pop = np.array([[1, 0], [0, 1]])
    result = calc_seq_div(first_pop, second_pop)
    wanted = np.array([[2, 0], [0, 2]])
    assert np.array_equal(
        result, wanted
    ), f"Failed on test case 1 with result {result}"

    # Test case 2: identical, monomorphic populations give zero divergence everywhere
    all_ones = np.array([[1, 1], [1, 1]])
    result = calc_seq_div(all_ones, all_ones.copy())
    wanted = np.array([[0, 0], [0, 0]])
    assert np.array_equal(
        result, wanted
    ), f"Failed on test case 2 with result {result}"

    # Test case 3: a biallelic population compared with itself
    biallelic = np.array([[0, 1, 1], [1, 0, 1], [1, 1, 0]])
    result = calc_seq_div(biallelic, biallelic.copy())
    wanted = np.array(
        [
            [0, 2, 2],
            [2, 0, 2],
            [2, 2, 0],
        ]
    )
    assert np.array_equal(
        result, wanted
    ), f"Failed on test case 3 with result {result}"

    # Test case 4: a multiallelic population (values 0, 1, 2) compared with itself
    multiallelic = np.array(
        [
            [0, 1, 2],
            [1, 2, 0],
            [0, 2, 1],
        ]
    )
    result = calc_seq_div(multiallelic, multiallelic.copy())
    wanted = np.array(
        [
            [0, 3, 3],
            [3, 0, 3],
            [3, 3, 0],
        ]
    )
    assert np.array_equal(
        result, wanted
    ), f"Failed on test case 4 with result {result}"


def test_calc_rd():
    """Check calc_rd on two small inputs whose expected divergence ratio is 1."""
    # Test case 1: the source is equally diverged from the reference and the target
    source = np.array([[0, 1], [1, 0]])
    reference = np.array([[1, 0], [0, 1]])
    target = np.array([[1, 1], [0, 0]])
    result = calc_rd(reference, target, source)
    wanted = 1
    assert np.isclose(
        result, wanted
    ), f"Failed on test case 1 with result {result}"

    # Test case 2: the target equals the reference, so the ratio must be exactly one
    source = np.array([[0, 1], [1, 1]])
    reference = np.array([[1, 0], [0, 1]])
    target = reference.copy()
    result = calc_rd(reference, target, source)
    wanted = 1.0
    assert np.isclose(
        result, wanted
    ), f"Failed on test case 2 with result {result}"
Loading

0 comments on commit bd8d5d2

Please sign in to comment.