Commit 9aff1eb1 authored by Christian Chapman-Bird's avatar Christian Chapman-Bird
Browse files

Patch fix for outdir update

parent 6b460aa7
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
@@ -53,7 +53,7 @@ def create_mlp(input_features, output_features, neurons, layers, activation, mod
    return model


def load_mlp(model_name, device, get_state_dict=False):
def load_mlp(model_name, device, get_state_dict=False, outdir='../models'):
    model = pickle.load(open(get_script_path()+f'/{outdir}/{model_name}/function.pickle', "rb"))  # load blank model
    if get_state_dict:
        model.load_state_dict(torch.load(open(get_script_path()+f'/{outdir}/{model_name}/model.pth', "rb"), map_location=device))