PromptEHR


Wang, Zifeng and Sun, Jimeng. (2022). PromptEHR: Conditional Electronic Healthcare Records Generation with Prompt Learning. EMNLP'22.

Usage

Load the pretrained PromptEHR model (trained on MIMIC-III sequential EHRs) in three lines:

from promptehr import PromptEHR

model = PromptEHR()

model.from_pretrained()

A Jupyter notebook example is available at https://github.com/RyanWangZf/PromptEHR/blob/main/example/demo_promptehr.ipynb.

How to install

First install the PyTorch build that matches your system by following https://pytorch.org/get-started/locally/.

Then install PromptEHR with

pip install git+https://github.com/RyanWangZf/PromptEHR.git

or

pip install promptehr
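
To confirm that the install succeeded, one option is to query the package metadata from Python. The sketch below uses only the standard library; `installed_version` is a hypothetical helper written for this illustration and is not part of PromptEHR. It assumes the distribution name is `promptehr`, as in the pip command above.

```python
from importlib import metadata
from typing import Optional


def installed_version(package: str) -> Optional[str]:
    """Return the installed version of `package`, or None if it is not installed.

    Hypothetical helper for illustration; not part of the PromptEHR API.
    """
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


if __name__ == "__main__":
    # Assumes the distribution is named "promptehr", as in `pip install promptehr`.
    version = installed_version("promptehr")
    if version is None:
        print("promptehr is not installed in this environment")
    else:
        print(f"promptehr {version} is installed")
```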

Load demo synthetic EHRs (generated by PromptEHR)

from promptehr import load_synthetic_data
data = load_synthetic_data()
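
Before passing the records to a model, it can help to check how many patients and visits they contain. The sketch below is a minimal example, not part of PromptEHR: `summarize_records` is a hypothetical helper, and it assumes that the returned dictionary has the keys `'visit'`, `'y'` and `'feature'` (the ones used in the generation example in this README), that `'visit'` holds one entry per patient, and that each entry is a sequence of visits. The toy dictionary stands in for the real data so that the snippet runs without the package.

```python
from typing import Any, Dict


def summarize_records(data: Dict[str, Any]) -> Dict[str, Any]:
    """Summarize a dictionary of patient records.

    Hypothetical helper. Assumes data['visit'] has one entry per patient and
    that each entry is a sequence of visits; data['y'] and data['feature']
    only need to support len().
    """
    visits = data["visit"]
    visit_counts = [len(patient) for patient in visits]
    return {
        "n_patients": len(visits),
        "n_labels": len(data["y"]),
        "n_feature_rows": len(data["feature"]),
        "max_visits": max(visit_counts, default=0),
        "mean_visits": sum(visit_counts) / len(visit_counts) if visit_counts else 0.0,
    }


# Toy stand-in for the real data (made-up values, for illustration only).
toy = {
    "visit": [[[1, 2], [3]], [[4]]],  # patient 0 has two visits, patient 1 has one
    "y": [0, 1],
    "feature": [[0.1], [0.2]],
}
summary = summarize_records(toy)
print(summary)
```

With the real data, `summarize_records(load_synthetic_data())` should work the same way if the assumptions above hold.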

Use PromptEHR for generation

from promptehr import SequencePatient
from promptehr import load_synthetic_data
from promptehr import PromptEHR

# init model
model = PromptEHR()
model.from_pretrained()

# load input data
demo = load_synthetic_data(n_sample=1000) # we have 10,000 samples in total

# build the standard input data for training or testing PromptEHR models
seqdata = SequencePatient(
    data={'v': demo['visit'], 'y': demo['y'], 'x': demo['feature']},
    metadata={
        'visit': {'mode': 'dense'},
        'label': {'mode': 'tensor'},
        'voc': demo['voc'],
        'max_visit': 20,
    },
)
# you can fit the model on this data with
# model.fit(seqdata)

# generate synthetic records
# n: the target total number of samples to generate
# n_per_sample: how many synthetic samples to generate from each input sample
# the output has the same format as `SequencePatient`
fake_data = model.predict(seqdata, n=1000, n_per_sample=10)