Loading EMRI_DET/nn/model_creation.py +1 −1 Original line number Diff line number Diff line Loading @@ -53,7 +53,7 @@ def create_mlp(input_features, output_features, neurons, layers, activation, mod return model def load_mlp(model_name, device, get_state_dict=False): def load_mlp(model_name, device, get_state_dict=False, outdir='../models'): model = pickle.load(open(get_script_path()+f'/{outdir}/{model_name}/function.pickle', "rb")) # load blank model if get_state_dict: model.load_state_dict(torch.load(open(get_script_path()+f'/{outdir}/{model_name}/model.pth', "rb"), map_location=device)) Loading Loading
EMRI_DET/nn/model_creation.py +1 −1 Original line number Diff line number Diff line Loading @@ -53,7 +53,7 @@ def create_mlp(input_features, output_features, neurons, layers, activation, mod return model def load_mlp(model_name, device, get_state_dict=False): def load_mlp(model_name, device, get_state_dict=False, outdir='../models'): model = pickle.load(open(get_script_path()+f'/{outdir}/{model_name}/function.pickle', "rb")) # load blank model if get_state_dict: model.load_state_dict(torch.load(open(get_script_path()+f'/{outdir}/{model_name}/model.pth', "rb"), map_location=device)) Loading