
Fitting a multistate model

Fitting a model

After preparing a dataset, we need to configure a few more settings:

1) Define terminal states in the model

Optional:

2) Define a custom update function for time-varying covariates (see the COVID hospitalization example). The default is no updating.
3) Define covariate columns
4) Define state labels
5) Define the minimum number of observed transitions needed to fit a transition
6) Define the event-specific fitter. The default is CoxWrapper. See custom_fitters for more details.
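The required and optional settings above can be gathered in one place and sanity-checked before the model is built. The sketch below uses a hypothetical `MSMConfig` dataclass that is not part of pymsm. Its defaults are placeholders, not pymsm's own defaults, and the checks in `validate()` are illustrative assumptions.

```python
from dataclasses import dataclass, field
from typing import Callable, Dict, List, Optional


@dataclass
class MSMConfig:
    """Hypothetical container for the six settings above; not part of pymsm."""

    terminal_states: List[int]                                  # 1) required
    update_covariates_fn: Optional[Callable] = None             # 2) None = no updating
    covariate_names: Optional[List[str]] = None                 # 3)
    state_labels: Dict[int, str] = field(default_factory=dict)  # 4)
    trim_transitions_threshold: int = 0                         # 5) placeholder default
    event_specific_fitter: Optional[type] = None                # 6) e.g. CoxWrapper

    def validate(self) -> None:
        """Raise ValueError for obviously inconsistent settings."""
        if not self.terminal_states:
            raise ValueError("at least one terminal state is required")
        if self.state_labels:
            unlabeled = [s for s in self.terminal_states if s not in self.state_labels]
            if unlabeled:
                raise ValueError(f"terminal states without a label: {unlabeled}")
        if self.trim_transitions_threshold < 0:
            raise ValueError("trim_transitions_threshold must be non-negative")


config = MSMConfig(
    terminal_states=[4],
    covariate_names=["is_male", "age", "was_severe"],
    state_labels={0: "C", 1: "R", 2: "M", 3: "S", 4: "D"},
    trim_transitions_threshold=10,
)
config.validate()  # passes: terminal state 4 has the label "D"
```

Validating up front catches problems such as a terminal state with no label before any fitting starts.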

# Load data
from pymsm.datasets import prep_covid_hosp_data

dataset, state_labels = prep_covid_hosp_data()


# 1) Define terminal states
terminal_states = [4]

# 2) Define a custom update function for time-varying covariates.
# The default performs no updating:
from pymsm.multi_state_competing_risks_model import default_update_covariates_function
update_covariates_fn = default_update_covariates_function

# Let's define one that flags patients who have left state 3 (severe):
def covid_update_covariates_function(
    covariates_entering_origin_state,
    origin_state=None,
    target_state=None,
    time_at_origin=None,
    abs_time_entry_to_target_state=None,
):
    covariates = covariates_entering_origin_state.copy()
    # update the was_severe covariate once the patient leaves state 3 (severe)
    if origin_state == 3:
        covariates['was_severe'] = 1
    return covariates

# 3) Define covariate columns
covariate_cols = ["is_male", "age", "was_severe"]

# 4) Define state labels
state_labels_short = {0: "C", 1: "R", 2: "M", 3: "S", 4: "D"}

# 5) Define the minimum number of observed transitions needed to fit a transition
trim_transitions_threshold = 10

# 6) Define the event-specific fitter. The default is CoxWrapper. See custom_fitters for more
from pymsm.event_specific_fitter import CoxWrapper

event_specific_fitter = CoxWrapper
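To see what `covid_update_covariates_function` does over a patient's trajectory, the sketch below applies it at each transition of a made-up path M, S, M, R, using the short labels from step 4. It assumes the model calls the update function once per observed transition with the origin and target states. The helper `covariates_along_path` is hypothetical, a plain dict stands in for the covariate record, and the patient's covariate values are invented.

```python
from typing import Callable, Dict, List

STATE_LABELS = {0: "C", 1: "R", 2: "M", 3: "S", 4: "D"}


def covid_update_covariates_function(
    covariates_entering_origin_state,
    origin_state=None,
    target_state=None,
    time_at_origin=None,
    abs_time_entry_to_target_state=None,
):
    # Copy of the function defined above so this snippet runs on its own.
    covariates = covariates_entering_origin_state.copy()
    if origin_state == 3:  # leaving the severe state
        covariates["was_severe"] = 1
    return covariates


def covariates_along_path(
    initial: Dict[str, int], states: List[int], update_fn: Callable
) -> List[Dict[str, int]]:
    """Hypothetical helper: covariates in effect on entry to each state in `states`."""
    history = [initial]
    for origin, target in zip(states, states[1:]):
        # Assumption: the update function is applied once per observed transition.
        history.append(update_fn(history[-1], origin_state=origin, target_state=target))
    return history


start = {"is_male": 1, "age": 60, "was_severe": 0}  # invented example patient
path = [2, 3, 2, 1]                                 # M -> S -> M -> R
history = covariates_along_path(start, path, covid_update_covariates_function)
for state, covariates in zip(path, history):
    print(STATE_LABELS[state], covariates["was_severe"])
# M 0
# S 0
# M 1
# R 1
```

The flag turns on only after the patient leaves state 3, so the covariates in effect while in S still have `was_severe = 0`. Because the function works on a copy, the original record is left unchanged.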

# Initialize the MultiStateModel
from pymsm.multi_state_competing_risks_model import MultiStateModel

multi_state_model = MultiStateModel(
    dataset=dataset,
    terminal_states=terminal_states,
    update_covariates_fn=covid_update_covariates_function,
    covariate_names=covariate_cols,
    state_labels=state_labels_short,
    event_specific_fitter=event_specific_fitter,
    trim_transitions_threshold=trim_transitions_threshold,
)

Once the model is initialized, we can fit it by calling its fit() method:

multi_state_model.fit()
Fitting Model at State: 2
>>> Fitting Transition to State: 1, n events: 2135
>>> Fitting Transition to State: 3, n events: 275
>>> Fitting Transition to State: 4, n events: 52
Fitting Model at State: 1
>>> Fitting Transition to State: 2, n events: 98
Fitting Model at State: 3
>>> Fitting Transition to State: 2, n events: 193
>>> Fitting Transition to State: 4, n events: 135
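The fit log lists, for each non-terminal origin state, one transition model per target state together with its number of observed events. The sketch below rebuilds that plan from the counts in the log and shows how `trim_transitions_threshold` could prune it. It assumes that a transition is fitted only when its event count is at least the threshold, and that these event counts are what the threshold is compared against. `transitions_to_fit` is a hypothetical helper, not a pymsm function.

```python
from collections import defaultdict
from typing import Dict, List, Tuple

# (origin, target) -> n events, copied from the fit log above.
EVENT_COUNTS: Dict[Tuple[int, int], int] = {
    (2, 1): 2135, (2, 3): 275, (2, 4): 52,
    (1, 2): 98,
    (3, 2): 193, (3, 4): 135,
}


def transitions_to_fit(
    counts: Dict[Tuple[int, int], int], terminal_states: List[int], threshold: int
) -> Dict[int, List[int]]:
    """Hypothetical helper: map each origin state to the targets that would be fitted."""
    plan: Dict[int, List[int]] = defaultdict(list)
    for (origin, target), n_events in counts.items():
        if origin in terminal_states:
            continue  # nothing is fitted out of a terminal state
        if n_events >= threshold:  # assumption: counts equal to the threshold are kept
            plan[origin].append(target)
    return {origin: sorted(targets) for origin, targets in plan.items()}


print(transitions_to_fit(EVENT_COUNTS, terminal_states=[4], threshold=10))
# {2: [1, 3, 4], 1: [2], 3: [2, 4]}
print(transitions_to_fit(EVENT_COUNTS, terminal_states=[4], threshold=100))
# {2: [1, 3], 3: [2, 4]}
```

With the threshold of 10 used here, every transition in the log clears the bar, which is consistent with all six being fitted. A threshold of 100 would drop the 2 -> 4 (52 events) and 1 -> 2 (98 events) transitions.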