Fitting a multistate model
Fitting a model
After preparing a dataset, a few more settings need to be configured:
1) Define the terminal states of the model
Optional:
2) Define a custom update function for time-varying covariates (see the COVID hospitalization example). By default, covariates are not updated.
3) Define the covariate columns
4) Define the state labels
5) Define the minimum number of observed instances of a transition required to fit that transition
6) Define the event-specific fitters. The default is CoxWrapper; see custom_fitters for alternatives.
# Load data
from pymsm.datasets import prep_covid_hosp_data
dataset, state_labels = prep_covid_hosp_data()
# 1) Define terminal states
terminal_states = [4]
# 2) Define a custom update function for time-varying covariates.
# The default is no updating:
from pymsm.multi_state_competing_risks_model import default_update_covariates_function
update_covariates_fn = default_update_covariates_function
# Let's define one:
def covid_update_covariates_function(
    covariates_entering_origin_state,
    origin_state=None,
    target_state=None,
    time_at_origin=None,
    abs_time_entry_to_target_state=None,
):
    # Copy so the covariates of the origin state are not modified in place
    covariates = covariates_entering_origin_state.copy()
    # Leaving the severe state (3) sets the was_severe covariate to 1
    if origin_state == 3:
        covariates['was_severe'] = 1
    return covariates
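To see what the update function does on its own, the sketch below calls it directly on a toy covariate record. It assumes pymsm passes the covariates as a pandas Series; a plain dict supports the same `.copy()` and item assignment, so one is used here to keep the example dependency-free. The patient values are made up for illustration.

```python
def covid_update_covariates_function(
    covariates_entering_origin_state,
    origin_state=None,
    target_state=None,
    time_at_origin=None,
    abs_time_entry_to_target_state=None,
):
    # Copy so the caller's covariates are not mutated
    covariates = covariates_entering_origin_state.copy()
    # Any transition out of the severe state (3) marks the patient as previously severe
    if origin_state == 3:
        covariates["was_severe"] = 1
    return covariates


# Hypothetical patient; a dict stands in for the pandas Series assumed above
patient = {"is_male": 1, "age": 63, "was_severe": 0}

# Leaving state 2: the covariates come back unchanged (as a new copy)
after_state_2 = covid_update_covariates_function(patient, origin_state=2, target_state=3)

# Leaving state 3: was_severe is switched on
after_state_3 = covid_update_covariates_function(after_state_2, origin_state=3, target_state=2)

print(after_state_2["was_severe"], after_state_3["was_severe"], patient["was_severe"])  # 0 1 0
```

Because the function copies before updating, the record stored for the origin state keeps its original values, and only the covariates carried into the next state change.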
# 3) Define covariate columns
covariate_cols = ["is_male", "age", "was_severe"]
# 4) Define state labels
state_labels_short = {0: "C", 1: "R", 2: "M", 3: "S", 4: "D"}
# 5) Define the minimum number of observed instances of a transition required to fit it
trim_transitions_threshold = 10
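The idea behind `trim_transitions_threshold` can be illustrated with a small counting sketch: tally how often each origin-to-target transition occurs across all paths and keep only the pairs observed at least `threshold` times. This is a hypothetical stand-in written for illustration, not pymsm's internal implementation; the helpers `count_transitions` and `transitions_to_fit` and the toy paths are made up.

```python
from collections import Counter


def count_transitions(paths):
    """Count (origin, target) pairs over a list of state sequences (hypothetical helper)."""
    counts = Counter()
    for states in paths:
        # zip pairs each state with the one that follows it along the path
        counts.update(zip(states, states[1:]))
    return counts


def transitions_to_fit(paths, threshold):
    """Keep only transitions observed at least `threshold` times (hypothetical helper)."""
    return {pair: n for pair, n in count_transitions(paths).items() if n >= threshold}


# Toy state sequences, one list per hypothetical patient
toy_paths = [[2, 1]] * 12 + [[2, 3, 4]] * 4 + [[2, 3, 2, 1]] * 7

print(transitions_to_fit(toy_paths, threshold=10))  # {(2, 1): 19, (2, 3): 11}
```

With a threshold of 10, the rarer 3→4 (4 occurrences) and 3→2 (7 occurrences) transitions fall below the cutoff in this sketch. Raising the threshold trades coverage of rare transitions for fits that rest on more data.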
# 6) Define the event-specific fitters. The default is CoxWrapper; see custom_fitters for alternatives.
from pymsm.event_specific_fitter import CoxWrapper
event_specific_fitter = CoxWrapper
# Initialize the MultiStateModel
from pymsm.multi_state_competing_risks_model import MultiStateModel
multi_state_model = MultiStateModel(
    dataset=dataset,
    terminal_states=terminal_states,
    update_covariates_fn=covid_update_covariates_function,
    covariate_names=covariate_cols,
    state_labels=state_labels_short,
    event_specific_fitter=event_specific_fitter,
    trim_transitions_threshold=trim_transitions_threshold,
)
Once the model is initialized, fit it by calling its fit() method.