Multistate model
This class fits a competing risks model per state; that is, it treats all transitions out of a state as competing risks. See the CompetingRisksModel class.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
dataset | Union[List[PathObject], DataFrame] | Either a list of PathObject objects or a pandas DataFrame in the format used to fit the CompetingRisksModel class. This dataset is used to fit a competing risks model for each state. | required |
terminal_states | List[int] | States which a sample does not leave. | required |
update_covariates_fn | Callable[ [Series, int, int, float, float], Series ] | A state-transition function that updates the time-dependent variables. It is used when fitting the model, so that you do not have to compute the features at each state manually, and in the Monte Carlo method of predicting the survival curves per sample. Defaults to default_update_covariates_function. | default_update_covariates_function |
covariate_names | List[str] | Optional list of covariate names to be used. If None, the names are taken from the index of the covariates Series of the first PathObject in the dataset. Defaults to None. | None |
event_specific_fitter | EventSpecificFitter | The model class that will be fitted inside the CompetingRisksModel. Defaults to CoxWrapper. | CoxWrapper |
competing_risk_data_format | bool | Indicates the format of the dataset parameter. If False, the dataset is assumed to be a list of PathObjects; if True, it is assumed to be a DataFrame in the format used to fit the CompetingRisksModel class. Defaults to False. | False |
state_labels | Dict[int, str] | A dictionary of short state labels. If None and the dataset is a list of PathObjects, each state's number is used as its label. Defaults to None. | None |
trim_transitions_threshold | int | The minimal number of transitions needed in the data to build a transition model. For transitions observed fewer times than this threshold, no model is built and their data are discarded. Defaults to 0. | 0 |
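To make the competing-risks data format concrete, here is a minimal sketch of how one path (a list of states and the time spent in each) can be expanded into one row per transition, mirroring the loop in _prepare_dataset_for_competing_risks_fit. The helper path_to_transition_rows is hypothetical, covariates and sample weights are omitted, and it assumes right censoring is coded as target state 0 (the transition table labels state 0 as "Censored").

```python
from typing import List

import pandas as pd

RIGHT_CENSORING = 0  # assumption: censored transitions are coded as target state 0


def path_to_transition_rows(
    sample_id: int,
    states: List[int],
    time_at_each_state: List[float],
    terminal_states: List[int],
) -> pd.DataFrame:
    """Hypothetical helper: expand one path into one row per observed or censored transition."""
    rows = []
    time_entry_to_origin = 0.0
    for i, time_in_origin in enumerate(time_at_each_state):
        origin_state = states[i]
        # A sojourn with no following state ends in right censoring.
        target_state = states[i + 1] if i + 1 < len(states) else RIGHT_CENSORING
        time_transition_to_target = time_entry_to_origin + time_in_origin
        rows.append(
            {
                "sample_id": sample_id,
                "origin_state": origin_state,
                "target_state": target_state,
                "time_entry_to_origin": time_entry_to_origin,
                "time_transition_to_target": time_transition_to_target,
            }
        )
        if target_state == RIGHT_CENSORING or target_state in terminal_states:
            break
        time_entry_to_origin = time_transition_to_target
    return pd.DataFrame(rows)


# Path 1 reaches terminal state 3; path 2 is censored while in state 2.
dataset = pd.concat(
    [
        path_to_transition_rows(1, [1, 2, 3], [4, 6], terminal_states=[3]),
        path_to_transition_rows(2, [1, 2], [5, 3], terminal_states=[3]),
    ],
    ignore_index=True,
)
print(dataset)
```

Each row carries both the entry time and the transition time, so a per-state model can treat the data as left-truncated: a sample only becomes at risk in a state once it has entered it.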
Attributes:
Name | Type | Description |
---|---|---|
state_specific_models | Dict[int, CompetingRisksModel] | A dictionary of CompetingRisksModel objects, one for each state. Available after running the fit method. |
Note
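The update_covariates_fn can be any function you write, but it must accept the following parameters, in this order: a pandas Series of covariates, an int (the origin state), an int (the target state), a float (the time spent in the origin state), and a float (the time of entry into the target state). It must return a pandas Series.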
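For example, a custom function might count previous transitions and record the time spent in the state just left. This is a minimal sketch: the covariate names n_previous_transitions and time_in_previous_state are hypothetical, and the argument meanings follow how the source calls update_covariates_fn in _one_monte_carlo_run.

```python
import pandas as pd


def update_covariates(
    covariates: pd.Series,
    origin_state: int,
    target_state: int,
    time_in_origin: float,
    time_entry_to_target: float,
) -> pd.Series:
    """Hypothetical update: count transitions and store the last sojourn time."""
    updated = covariates.copy()  # do not mutate the caller's Series
    # Assumed covariate names; they must already exist in the Series.
    updated["n_previous_transitions"] = updated["n_previous_transitions"] + 1
    updated["time_in_previous_state"] = time_in_origin
    # origin_state, target_state and time_entry_to_target are unused here,
    # but the signature must still accept them.
    return updated


covariates = pd.Series(
    {"age": 60.0, "n_previous_transitions": 0.0, "time_in_previous_state": 0.0}
)
after_first_transition = update_covariates(covariates, 1, 2, 3.5, 3.5)
print(after_first_transition)
```

Pass it as update_covariates_fn=update_covariates when constructing the model. Keeping the same index as the input Series avoids a mismatch with covariate_names.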
Source code in pymsm/multi_state_competing_risks_model.py
class MultiStateModel:
    """This class fits a competing risks model per state; that is, it treats all transitions out of a state as competing risks. See the CompetingRisksModel class.
    Args:
        dataset (Union[List[PathObject], DataFrame]): Either a list of PathObject objects or a pandas DataFrame in the format used to fit the CompetingRisksModel class. Dataset used to fit a competing risks model to each state.
        terminal_states (List[int]): States which a sample does not leave.
        update_covariates_fn (Callable[ [Series, int, int, float, float], Series ], optional): A state-transition function that updates the time-dependent variables. It is used when fitting the model, so that the user does not have to compute the features at each state manually, and in the Monte Carlo method of predicting the survival curves per sample. Defaults to default_update_covariates_function.
        covariate_names (List[str], optional): Optional list of covariate names to be used. Defaults to None.
        event_specific_fitter (EventSpecificFitter, optional): The model class that will be fitted inside the CompetingRisksModel. Defaults to CoxWrapper.
        competing_risk_data_format (bool, optional): Indicates the format of the dataset parameter. If False, the dataset is assumed to be a list of PathObjects; if True, it is assumed to be a DataFrame in the format used to fit the CompetingRisksModel class. Defaults to False.
        state_labels (Dict[int, str], optional): A dictionary of short state labels. Defaults to None.
        trim_transitions_threshold (int): The minimal number of transitions needed in the data to build a transition model. For transitions observed fewer times than this threshold, no model is built and their data are discarded. Defaults to 0.
    Attributes:
        state_specific_models (Dict[int, CompetingRisksModel]): A dictionary of CompetingRisksModel objects, one for each state. Available after running the fit method.
    Note:
        The update_covariates_fn can be any function you write, but it must accept the following parameter
        types (in this order): pandas Series, int, int, float, float; and return a pandas Series.
    """
    def __init__(
        self,
        dataset: Union[List[PathObject], DataFrame],
        terminal_states: List[int],
        update_covariates_fn: Callable[
            [Series, int, int, float, float], Series
        ] = default_update_covariates_function,
        covariate_names: List[str] = None,
        event_specific_fitter: EventSpecificFitter = CoxWrapper,
        competing_risk_data_format: bool = False,
        state_labels: Dict[int, str] = None,
        trim_transitions_threshold: int = 0,
    ):
        self.dataset = dataset
        self.terminal_states = terminal_states
        self.update_covariates_fn = update_covariates_fn
        self.covariate_names = self._get_covariate_names(covariate_names)
        self.state_specific_models: Dict[int, CompetingRisksModel] = dict()
        self._time_is_discrete: bool = None
        self.competing_risk_dataset: DataFrame = None
        self._samples_have_weights: bool = False
        self._competing_risk_data_format = competing_risk_data_format
        self._event_specific_fitter = event_specific_fitter
        self.state_labels = state_labels
        self.trim_transitions_threshold = trim_transitions_threshold
        self.transition_matrix: DataFrame = None
        self.transition_table: DataFrame = None
        self.state_diagram_graph_string: str = None
        if self._competing_risk_data_format:
            self.competing_risk_dataset = dataset
            if self.trim_transitions_threshold > 0:
                self._trim_transitions()
        else:
            self._assert_valid_input()
    def fit(self, verbose: int = 1) -> None:
        """Fit a CompetingRisksModel for each state
        Args:
            verbose (int, optional): Verbosity level; values of 1 or more print progress and plot the state diagram. Defaults to 1.
        """
        # In competing-risk format, reuse self.competing_risk_dataset so that any trimming done in __init__ is kept
        self.competing_risk_dataset = (
            self.competing_risk_dataset
            if self._competing_risk_data_format
            else self._prepare_dataset_for_competing_risks_fit()
        )
        self._time_is_discrete = self._check_if_time_is_discrete()
        for state in self.competing_risk_dataset["origin_state"].unique():
            if verbose >= 1:
                print("Fitting Model at State: {}".format(state))
            model = self._fit_state_specific_model(state, verbose)
            self.state_specific_models[state] = model
        if verbose >= 1:
            self.plot_state_diagram()
    def _assert_valid_input(self) -> None:
        """Checks that the dataset is valid for running the multi-state competing risks model"""
        # Check the number of times is either equal to or one less than the number of states
        for obj in self.dataset:
            n_states = len(obj.states)
            n_times = len(obj.time_at_each_state)
            assert n_states == n_times or n_states == n_times + 1
            if n_states == 1 and obj.states[0] in self.terminal_states:
                obj.print_path()
                exit(
                    "Error: encountered a sample with a single state that is a terminal state."
                )
        # Check either all objects have an id or none do
        has_id = [obj for obj in self.dataset if obj.sample_id is not None]
        assert len(has_id) == len(self.dataset) or len(has_id) == 0
        # Check either all objects have a sample weight or none do
        has_weight = [obj for obj in self.dataset if obj.sample_weight is not None]
        assert len(has_weight) == len(self.dataset) or len(has_weight) == 0
        self._samples_have_weights = len(has_weight) > 0
        # Check all covariates are of the same length
        cov_len = len(self.dataset[0].covariates)
        same_length = [obj for obj in self.dataset if len(obj.covariates) == cov_len]
        assert len(same_length) == len(self.dataset)
        # Check length of covariate names matches the length of covariates in PathObject
        if self.covariate_names is None:
            return
        assert len(self.covariate_names) == len(self.dataset[0].covariates)
    def _get_covariate_names(self, covariate_names: List[str]) -> List[str]:
        """This function sets the covariate names that will be used in prints.
        Names are taken either from the covariate names provided by the user,
        or, if None is provided, from the named pandas Series of covariates of the first PathObject in the dataset
        Args:
            covariate_names (List[str], optional): covariate names provided in class init
        Returns:
            List: List of covariate names
        """
        if covariate_names is not None:
            return covariate_names
        return self.dataset[0].covariates.index.to_list()
    def _check_if_time_is_discrete(self) -> bool:
        """This function checks whether the time in the dataset is discrete"""
        times = (
            self.competing_risk_dataset["time_entry_to_origin"].values.tolist()
            + self.competing_risk_dataset["time_transition_to_target"].values.tolist()
        )
        return all(isinstance(t, int) for t in times)
    def _prepare_dataset_for_competing_risks_fit(self) -> DataFrame:
        """This function converts the given dataset (list of PathObjects) to a pandas DataFrame that will be used when
        fitting the CompetingRisksModel class
        """
        self.competing_risk_dataset = DataFrame()
        for obj in self.dataset:
            origin_state = obj.states[0]
            covs_entering_origin = Series(
                dict(zip(self.covariate_names, obj.covariates.values))
            )
            time_entry_to_origin = 0
            for i, state in enumerate(obj.states):
                transition_row = {}
                time_in_origin = obj.time_at_each_state[i]
                time_transition_to_target = time_entry_to_origin + time_in_origin
                target_state = (
                    obj.states[i + 1] if i + 1 < len(obj.states) else RIGHT_CENSORING
                )
                # append row corresponding to this transition
                transition_row["sample_id"] = obj.sample_id
                if self._samples_have_weights:
                    transition_row["sample_weight"] = obj.sample_weight
                transition_row["origin_state"] = origin_state
                transition_row["target_state"] = target_state
                transition_row["time_entry_to_origin"] = time_entry_to_origin
                transition_row["time_transition_to_target"] = time_transition_to_target
                transition_row.update(covs_entering_origin.to_dict())
                self.competing_risk_dataset = pd.concat(
                    [self.competing_risk_dataset, pd.DataFrame([transition_row])],
                    ignore_index=True,
                )
                if (
                    target_state == RIGHT_CENSORING
                    or target_state in self.terminal_states
                ):
                    break
                else:
                    # Set up for the next iteration; update_covariates_fn takes five arguments
                    covs_entering_origin = self.update_covariates_fn(
                        covs_entering_origin,
                        origin_state,
                        target_state,
                        time_in_origin,
                        time_transition_to_target,
                    )
                    origin_state = target_state
                    time_entry_to_origin = time_transition_to_target
        self.competing_risk_dataset["sample_id"] = self.competing_risk_dataset[
            "sample_id"
        ].astype(int)
        self.competing_risk_dataset["origin_state"] = self.competing_risk_dataset[
            "origin_state"
        ].astype(int)
        self.competing_risk_dataset["target_state"] = self.competing_risk_dataset[
            "target_state"
        ].astype(int)
        # Create default state_labels if None provided
        if self.state_labels is None:
            unique_states = pd.unique(
                self.competing_risk_dataset[
                    ["origin_state", "target_state"]
                ].values.ravel("K")
            )  # find unique possible states
            unique_states.sort()  # sort
            unique_states = unique_states[unique_states > 0]  # drop censored
            self.state_labels = dict(
                zip(unique_states, [str(s) for s in unique_states])
            )
        # trim transitions if needed
        if self.trim_transitions_threshold > 0:
            self._trim_transitions()
        return self.competing_risk_dataset
    def prep_transition_table(self):
        """This function creates a transition matrix (counts of origin-to-target transitions) from the dataset
        and returns a labeled copy of it as the transition table"""
        if self.competing_risk_dataset is None:
            self._prepare_dataset_for_competing_risks_fit()
        # Create transition matrix
        self.transition_matrix = pd.crosstab(
            self.competing_risk_dataset["origin_state"],
            self.competing_risk_dataset["target_state"],
        )
        # Rename rows and columns and get a transition table
        self.transition_table = self.transition_matrix.copy()
        rename_dict = self.state_labels.copy()
        rename_dict[0] = "Censored"
        self.transition_table.rename(columns=rename_dict, inplace=True)
        self.transition_table.rename(index=rename_dict, inplace=True)
        return self.transition_table
    def _trim_transitions(self) -> None:
        """For transitions observed fewer than trim_transitions_threshold times, discard their data in competing_risk_dataset"""
        # copy original dataset for any future use
        self._original_competing_risk_dataset = self.competing_risk_dataset.copy()
        if self.transition_matrix is None:
            self.prep_transition_table()
        # generate a list of tuples of transitions observed fewer than trim_transitions_threshold times
        i, j = np.where(self.transition_matrix < self.trim_transitions_threshold)
        invalid_transitions = list(
            zip(
                self.transition_matrix.index[i].values,
                self.transition_matrix.columns[j].values,
            )
        )
        # discard these transitions from the competing risk dataset
        for origin_state, target_state in invalid_transitions:
            if origin_state == target_state:
                continue
            self.competing_risk_dataset = self.competing_risk_dataset[
                ~(
                    (self.competing_risk_dataset["origin_state"] == origin_state)
                    & (self.competing_risk_dataset["target_state"] == target_state)
                )
            ]
        # Update transition_table
        self.prep_transition_table()
    def extract_state_diagram_string_from_transition_table(self) -> str:
        """This function extracts a Mermaid state diagram string"""
        if self.transition_table is None:
            self.prep_transition_table()
        graph = """stateDiagram-v2\n"""
        for s, state_label in self.state_labels.items():
            graph += f"""s{s} : ({s}) {state_label}\n"""
        for origin_state, row in self.transition_matrix.iterrows():
            for target_state in row.index:
                if target_state == 0:  # Censored transition
                    continue
                if row[target_state] == 0:  # Empty transition
                    continue
                num_transitions = row[target_state]
                graph += (
                    f"""s{origin_state} --> s{target_state}: {num_transitions} \n"""
                )
        graph += """\n"""
        self.state_diagram_graph_string = graph
        return graph
    def plot_state_diagram(self):
        """This function plots a mermaid state diagram for the model"""
        if self.state_diagram_graph_string is None:
            self.extract_state_diagram_string_from_transition_table()
        return state_diagram(self.state_diagram_graph_string)
    def _fit_state_specific_model(
        self, state: int, verbose: int = 1
    ) -> CompetingRisksModel:
        """Fit a CompetingRisksModel for a specific given state
        Args:
            state (int): State to fit the model for
            verbose (int, optional): verbosity. Defaults to 1.
        Returns:
            CompetingRisksModel: state specific model
        """
        state_specific_df = self.competing_risk_dataset[
            self.competing_risk_dataset["origin_state"] == state
        ].copy()
        state_specific_df.drop(["origin_state"], axis=1, inplace=True)
        state_specific_df.reset_index(drop=True, inplace=True)
        crm = CompetingRisksModel(self._event_specific_fitter)
        crm.fit(
            state_specific_df,
            event_col="target_state",
            duration_col="time_transition_to_target",
            cluster_col="sample_id",
            entry_col="time_entry_to_origin",
            verbose=verbose,
        )
        return crm
    def _assert_valid_simulation_input(
        self,
        sample_covariates: np.ndarray,
        origin_state: int,
        current_time: int,
        n_random_samples: int,
        max_transitions: int,
        n_jobs: int,
        print_paths: bool,
    ):
        """This function checks if the input to the simulation is valid."""
        # TODO assert valid inputs for sample_covariates (Series or np.ndarray, and length), origin_state
        assert current_time >= 0
        assert isinstance(n_random_samples, int)
        assert n_random_samples > 0
        assert isinstance(max_transitions, int)
        assert max_transitions > 0
        assert isinstance(print_paths, bool)
    def run_monte_carlo_simulation(
        self,
        sample_covariates: np.ndarray,  # TODO change to np.ndarray OR pd.Series
        origin_state: int,
        current_time: int = 0,
        n_random_samples: int = 100,
        max_transitions: int = 10,
        n_jobs: int = -1,
        print_paths: bool = False,
    ) -> List[PathObject]:
        """This function samples random paths using Monte Carlo simulation.
        These paths are used for prediction for a single sample.
        The initial sample covariates, along with the sample’s current state, are supplied.
        The next states are sequentially sampled via the model parameters.
        The process concludes when the sample arrives at a terminal state or the number of transitions reaches the
        specified maximum.
        Args:
            sample_covariates (np.ndarray): Initial sample covariates, when entering the origin state
            origin_state (int): Initial state where the path begins
            current_time (int, optional): Time when starting the sample path. Defaults to 0.
            n_random_samples (int, optional): Number of random paths to create. Defaults to 100.
            max_transitions (int, optional): Maximum number of transitions allowed in each path. Defaults to 10.
            n_jobs (int, optional): Number of parallel jobs to run; None runs the simulations sequentially. Defaults to -1.
            print_paths (bool, optional): Whether to print the paths. Defaults to False.
        Returns:
            List[PathObject]: list of length n_random_samples, containing the randomly created PathObjects
        """
        # Check input is valid
        self._assert_valid_simulation_input(
            sample_covariates,
            origin_state,
            current_time,
            n_random_samples,
            max_transitions,
            n_jobs,
            print_paths,
        )
        if n_jobs is None:  # no parallelization
            runs = []
            for i in tqdm(range(0, n_random_samples)):
                runs.append(
                    self._one_monte_carlo_run(
                        sample_covariates, origin_state, max_transitions, current_time
                    )
                )
        else:  # Run parallel jobs
            runs = Parallel(n_jobs=n_jobs)(
                delayed(self._one_monte_carlo_run)(
                    sample_covariates, origin_state, max_transitions, current_time
                )
                for i in tqdm(range(0, n_random_samples))
            )
        if print_paths:
            self._print_paths(runs)
        return runs
    def _one_monte_carlo_run(
        self,
        sample_covariates: np.ndarray,
        origin_state: int,
        max_transitions: int,
        current_time: int = 0,
    ) -> PathObject:
        """This function creates one path using Monte Carlo simulation.
        See the documentation of run_monte_carlo_simulation.
        """
        run = PathObject(states=list(), time_at_each_state=list())
        run.stopped_early = False
        current_state = origin_state
        for i in range(0, max_transitions):
            next_state = self._sample_next_state(
                current_state, sample_covariates, current_time
            )
            if next_state is None:
                run.stopped_early = True
                return run
            time_to_next_state = self._sample_time_to_next_state(
                current_state, next_state, sample_covariates, current_time
            )
            run.states.append(current_state)
            run.time_at_each_state.append(time_to_next_state)
            if next_state in self.terminal_states:
                run.states.append(next_state)
                break
            else:
                # entry time into the target state is the current time plus the sojourn time
                time_entry_to_target = current_time + time_to_next_state
                sample_covariates = self.update_covariates_fn(
                    sample_covariates,
                    current_state,
                    next_state,
                    time_to_next_state,
                    time_entry_to_target,
                )
            current_state = next_state
            current_time = current_time + time_to_next_state
        return run
    def _probability_for_next_state(
        self,
        next_state: int,
        competing_risks_model: CompetingRisksModel,
        sample_covariates: np.ndarray,
        t_entry_to_current_state: int = 0,
    ):
        """This function calculates the probability of transition to the next state, using the competing_risks_model
        model parameters
        """
        unique_event_times = competing_risks_model.unique_event_times(next_state)
        if self._time_is_discrete:
            mask = unique_event_times > np.floor(t_entry_to_current_state + 1)
        else:
            mask = unique_event_times > t_entry_to_current_state
        # hazard for the failure type corresponding to 'state':
        hazard = competing_risks_model.hazard_at_unique_event_times(
            sample_covariates, next_state
        )
        hazard = hazard[mask]
        # overall survival function evaluated at time of failures corresponding to 'state'
        survival = competing_risks_model.survival_function(
            unique_event_times[mask], sample_covariates
        )
        probability_for_state = np.nansum(hazard * survival)
        return probability_for_state
    def _sample_next_state(
        self,
        current_state: int,
        sample_covariates: np.ndarray,
        t_entry_to_current_state: int,
    ) -> Optional[int]:
        """This function samples the next state according to a multinomial distribution, using the probabilities
        defined by the _probability_for_next_state function.
        """
        competing_risk_model = self.state_specific_models[current_state]
        possible_next_states = competing_risk_model.failure_types
        # compute probabilities for multinomial distribution
        probabilites = {}
        for state in possible_next_states:
            probabilites[state] = self._probability_for_next_state(
                state, competing_risk_model, sample_covariates, t_entry_to_current_state
            )
        # when no transition after t_entry_to_current_state was seen
        if all(value == 0 for value in probabilites.values()):
            return None
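        # normalize so the probabilities sum to one; otherwise np.random.multinomial
        # assigns any leftover probability mass to the last state
        pvals = np.array(list(probabilites.values()), dtype=float)
        mult = np.random.multinomial(1, pvals / pvals.sum())
        next_state = possible_next_states[mult.argmax()]
        return next_state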
    def _sample_time_to_next_state(
        self,
        current_state: int,
        next_state: int,
        sample_covariates: np.ndarray,
        t_entry_to_current_state: int,
    ) -> float:
        """This function samples the time of transition to the next state, using the hazard and survival provided by
        the competing risk model of the current_state
        """
        competing_risk_model = self.state_specific_models[current_state]
        unique_event_times = competing_risk_model.unique_event_times(next_state)
        # ensure discrete variables are sampled from the next time unit
        if self._time_is_discrete:
            mask = unique_event_times > np.floor(t_entry_to_current_state + 1)
        else:
            mask = unique_event_times > t_entry_to_current_state
        unique_event_times = unique_event_times[mask]
        # hazard for the failure type corresponding to 'next_state'
        hazard = competing_risk_model.hazard_at_unique_event_times(
            sample_covariates, next_state
        )
        hazard = hazard[mask]
        # overall survival function evaluated at time of failures corresponding to 'next_state'
        survival = competing_risk_model.survival_function(
            unique_event_times, sample_covariates
        )
        probability_for_each_t = np.nancumsum(hazard * survival)
        probability_for_each_t_given_next_state = (
            probability_for_each_t / probability_for_each_t.max()
        )
        # take the latest event time whose normalized cumulative probability is <= eps;
        # the earliest remaining event time is always a candidate, which covers a very small eps
        eps = np.random.uniform(size=1)
        possible_times = np.concatenate(
            (
                unique_event_times[probability_for_each_t_given_next_state <= eps],
                [unique_event_times[0]],
            )
        )
        time_to_next_state = possible_times.max()
        time_to_next_state = time_to_next_state - t_entry_to_current_state
        return time_to_next_state
    def _print_paths(self, mc_paths):
        """Helper function for printing the paths of a Monte Carlo simulation"""
        for mc_path in mc_paths:
            mc_path.print_path()
            print("\n")
extract_state_diagram_string_from_transition_table(self)
This function extracts a Mermaid state diagram string from the transition table and stores it in state_diagram_graph_string.
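A self-contained sketch of the same idea, assuming the transition matrix is a pandas crosstab of origin_state by target_state in which column 0 holds censored observations. The function name mermaid_state_diagram is hypothetical; like the method, it skips the censoring column and empty cells.

```python
from typing import Dict

import pandas as pd


def mermaid_state_diagram(transition_matrix: pd.DataFrame, state_labels: Dict[int, str]) -> str:
    """Hypothetical helper: build a Mermaid stateDiagram-v2 string from transition counts."""
    lines = ["stateDiagram-v2"]
    for state, label in state_labels.items():
        lines.append(f"s{state} : ({state}) {label}")
    for origin_state, row in transition_matrix.iterrows():
        for target_state, count in row.items():
            if target_state == 0 or count == 0:  # skip censoring and empty cells
                continue
            lines.append(f"s{origin_state} --> s{target_state}: {count}")
    return "\n".join(lines) + "\n"


df = pd.DataFrame({"origin_state": [1, 1, 2, 2], "target_state": [2, 0, 3, 3]})
matrix = pd.crosstab(df["origin_state"], df["target_state"])
diagram = mermaid_state_diagram(matrix, {1: "Healthy", 2: "Sick", 3: "Dead"})
print(diagram)
```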
fit(self, verbose=1)
Fit a CompetingRisksModel for each state.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
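verbose | int | Verbosity level. Values of 1 or more print the state being fitted and plot the state diagram; the value is also passed on to each state's CompetingRisksModel fit. Defaults to 1. | 1 |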
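Fitting works state by state: the combined dataset is filtered on origin_state, that column is dropped, and the index is reset before the rows are handed to a CompetingRisksModel. A minimal pandas sketch of that split (split_by_origin_state is a hypothetical helper); in the class, each per-state frame is then passed to CompetingRisksModel.fit with event_col="target_state", duration_col="time_transition_to_target", entry_col="time_entry_to_origin", and cluster_col="sample_id".

```python
from typing import Dict

import pandas as pd


def split_by_origin_state(competing_risk_dataset: pd.DataFrame) -> Dict[int, pd.DataFrame]:
    """Hypothetical helper: one frame per origin state, without the origin_state column."""
    per_state = {}
    for state in competing_risk_dataset["origin_state"].unique():
        state_df = competing_risk_dataset[competing_risk_dataset["origin_state"] == state]
        # drop() returns a new frame, so the original dataset is left untouched
        per_state[int(state)] = state_df.drop(columns=["origin_state"]).reset_index(drop=True)
    return per_state


dataset = pd.DataFrame(
    {
        "sample_id": [1, 1, 2],
        "origin_state": [1, 2, 1],
        "target_state": [2, 3, 0],
        "time_entry_to_origin": [0, 4, 0],
        "time_transition_to_target": [4, 10, 7],
    }
)
per_state = split_by_origin_state(dataset)
print(per_state[1])
```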
plot_state_diagram(self)
This function plots a mermaid state diagram for the model
prep_transition_table(self)
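This function creates a transition matrix (counts of transitions from each origin state to each target state) from the dataset and returns a labeled copy as the transition table, in which target state 0 is labeled "Censored".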
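A sketch of the transition matrix and of the trimming step that trim_transitions_threshold triggers, using pandas.crosstab as the source does. The helper names build_transition_table and trim_rare_transitions are hypothetical; like _trim_transitions, the trimming skips cells where origin and target coincide.

```python
from typing import Dict, Tuple

import numpy as np
import pandas as pd


def build_transition_table(
    df: pd.DataFrame, state_labels: Dict[int, str]
) -> Tuple[pd.DataFrame, pd.DataFrame]:
    """Hypothetical helper: count origin -> target transitions and return a labeled copy too."""
    matrix = pd.crosstab(df["origin_state"], df["target_state"])
    labels = dict(state_labels)
    labels[0] = "Censored"  # target state 0 denotes right censoring
    return matrix, matrix.rename(index=labels, columns=labels)


def trim_rare_transitions(df: pd.DataFrame, matrix: pd.DataFrame, threshold: int) -> pd.DataFrame:
    """Hypothetical helper: drop rows whose (origin, target) pair occurs fewer than `threshold` times."""
    rows, cols = np.where(matrix.to_numpy() < threshold)
    for origin, target in zip(matrix.index[rows], matrix.columns[cols]):
        if origin == target:
            continue
        df = df[~((df["origin_state"] == origin) & (df["target_state"] == target))]
    return df


df = pd.DataFrame({"origin_state": [1, 1, 1, 2], "target_state": [2, 2, 3, 3]})
matrix, table = build_transition_table(df, {1: "Healthy", 2: "Sick", 3: "Dead"})
trimmed = trim_rare_transitions(df, matrix, threshold=2)
print(table)
print(trimmed)
```

With a threshold of 2, the transitions 1 -> 3 and 2 -> 3 are each observed once, so their rows are discarded and only the two 1 -> 2 rows remain.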
run_monte_carlo_simulation(self, sample_covariates, origin_state, current_time=0, n_random_samples=100, max_transitions=10, n_jobs=-1, print_paths=False)
This function samples random paths using Monte Carlo simulation. These paths are used for prediction for a single sample. The initial sample covariates, along with the sample’s current state, are supplied. The next states are sequentially sampled via the model parameters. The process concludes when the sample arrives at a terminal state or the number of transitions reaches the specified maximum.
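The simulation loop can be sketched in a few lines. This is a simplified stand-in for _one_monte_carlo_run, not the library's code: the per-state sampler is a hypothetical callable standing in for the fitted competing risks models, the toy probabilities and exponential times are made up for illustration, and covariate updates are omitted. Repeating the simulation many times gives an empirical distribution of paths, from which quantities such as the share of paths passing through a given state can be estimated.

```python
import random
from typing import Callable, List, Optional, Tuple

# Hypothetical sampler: (current_state, current_time) -> (next_state, time_in_state),
# or None when no further transition can be drawn.
Sampler = Callable[[int, float], Optional[Tuple[int, float]]]


def simulate_path(
    origin_state: int,
    sampler: Sampler,
    terminal_states: List[int],
    max_transitions: int,
    current_time: float = 0.0,
) -> Tuple[List[int], List[float]]:
    states: List[int] = []
    times: List[float] = []
    state = origin_state
    for _ in range(max_transitions):
        draw = sampler(state, current_time)
        if draw is None:  # no transition possible: the path stops early
            break
        next_state, time_in_state = draw
        states.append(state)
        times.append(time_in_state)
        if next_state in terminal_states:
            states.append(next_state)  # record the absorbing state, then stop
            break
        state = next_state
        current_time += time_in_state
    return states, times


def make_toy_sampler(seed: int) -> Sampler:
    """Toy three-state model: 1 -> 2 (70%) or 1 -> 3 (30%); 2 -> 3."""
    rng = random.Random(seed)

    def sampler(state: int, current_time: float) -> Optional[Tuple[int, float]]:
        if state == 1:
            next_state = 2 if rng.random() < 0.7 else 3
        else:
            next_state = 3
        return next_state, rng.expovariate(1.0)  # toy exponential sojourn time

    return sampler


sampler = make_toy_sampler(seed=0)
paths = [simulate_path(1, sampler, terminal_states=[3], max_transitions=10) for _ in range(1000)]
share_through_2 = sum(2 in states for states, _ in paths) / len(paths)
print(share_through_2)
```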
Parameters:
Name | Type | Description | Default |
---|---|---|---|
sample_covariates | np.ndarray | Initial sample covariates, when entering the origin state. | required |
origin_state | int | Initial state where the path begins. | required |
current_time | int | Time when starting the sample path. Defaults to 0. | 0 |
n_random_samples | int | Number of random paths to create. Defaults to 100. | 100 |
max_transitions | int | Maximum number of transitions allowed in each path. Defaults to 10. | 10 |
n_jobs | int | Number of parallel jobs to run; None runs the simulations sequentially. Defaults to -1. | -1 |
print_paths | bool | Whether to print the paths. Defaults to False. | False |
Returns:
Type | Description |
---|---|
List[PathObject] | A list of length n_random_samples containing the randomly created PathObjects. |
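Each simulated step draws a next state and then a sojourn time. The sketch below illustrates both draws on toy arrays. The next-state weights follow the formula in _probability_for_next_state (the sum of hazard times overall survival over event times after entry), and the sketch normalizes them before drawing. The time is drawn with the standard inverse-transform rule (the first event time whose cumulative probability reaches a uniform draw), which may not match the candidate-selection rule in _sample_time_to_next_state exactly. The arrays are made-up numbers, not model output, all candidate states share one event-time grid for simplicity, and all function names are hypothetical.

```python
from typing import Dict

import numpy as np


def next_state_probabilities(
    event_times: np.ndarray, hazards: Dict[int, np.ndarray], survival: np.ndarray, t_entry: float
) -> Dict[int, float]:
    """Per candidate state, sum hazard * survival over event times after entry."""
    mask = event_times > t_entry
    return {state: float(np.nansum(h[mask] * survival[mask])) for state, h in hazards.items()}


def sample_next_state(probabilities: Dict[int, float], rng: np.random.Generator) -> int:
    states = list(probabilities)
    p = np.array([probabilities[s] for s in states], dtype=float)
    return int(rng.choice(states, p=p / p.sum()))  # normalized before drawing


def sample_time_to_next_state(
    event_times: np.ndarray, hazard: np.ndarray, survival: np.ndarray, t_entry: float, rng: np.random.Generator
) -> float:
    mask = event_times > t_entry
    times = event_times[mask]
    cdf = np.nancumsum(hazard[mask] * survival[mask])
    cdf = cdf / cdf[-1]  # cumulative probability given this transition (assumes a positive total)
    u = rng.uniform()
    idx = int(np.searchsorted(cdf, u, side="left"))  # first time with cdf >= u
    return float(times[idx] - t_entry)


event_times = np.array([1.0, 2.0, 3.0, 4.0])
hazards = {2: np.full(4, 0.10), 3: np.full(4, 0.05)}  # toy cause-specific hazards
survival = np.array([0.85, 0.72, 0.61, 0.52])  # toy overall survival
rng = np.random.default_rng(0)
probs = next_state_probabilities(event_times, hazards, survival, t_entry=1.0)
nxt = sample_next_state(probs, rng)
dt = sample_time_to_next_state(event_times, hazards[nxt], survival, 1.0, rng)
print(probs, nxt, dt)
```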