import os
import sys
from datetime import datetime
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
sys.path.append(os.path.abspath("/home/jg3837/DataFed_TorchFlow/DataFed_TorchFlow/src"))
import getpass
import json
import logging
import pathlib
import traceback
import types
import numpy as np
from datafed.CommandLib import API
from m3util.globus.globus import check_globus_file_access
from m3util.util.IO import find_files_recursive, make_folder
from tqdm import tqdm
from datafed_torchflow.computer import get_system_info
from datafed_torchflow.datafed import DataFed
from datafed_torchflow.utils import (
extract_instance_attributes,
getNotebookMetadata,
serialize_model,
serialize_pytorch_optimizer,
)
# TODO: Add data and dataloader derivative.
# TODO: Add number of FLOPS to metadata
class TorchLogger:
"""
TorchLogger is a class designed to log PyTorch model training details,
including model architecture, optimizer state, and system information.
It also integrates with the DataFed API for file and metadata management.
Attributes:
model_dict (dict): A dictionary containing the PyTorch model blocks to be logged,
with the name of each block as the key and the block as the value.
For example: {"vae": vae, "encoder": encoder, "decoder": decoder, "optimizer": optimizer}
DataFed_path (str): The path to the DataFed configuration or API.
script_path (str): Path to the script or notebook for checksum calculation.
local_model_path (str): Local directory to store model files.
input_data_shape (tuple): Shape of the input training data for the model.
logging (bool): Whether to display logging output.
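Example:
A minimal, hypothetical usage sketch (the vae and optimizer objects, the DataFed
collection path, and the notebook name below are placeholders, not values defined
in this module):

    logger = TorchLogger(
        model_dict={"vae": vae, "optimizer": optimizer},
        DataFed_path="ProjectCollection/vae_runs",
        script_path="train_vae.ipynb",
        local_model_path="./checkpoints",
        input_data_shape=(256, 1, 64, 64),
    )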
"""
def __init__(
self,
model_dict,
DataFed_path,
script_path=None,
local_model_path="/.",
log_file_path="log.txt",
input_data_shape=None,
dataset_id_or_path=None,
logging=False,
download_kwargs={"wait": True, "orig_fname": True},
):
"""
Initializes the TorchLogger class.
Args:
model_dict (dict): A dictionary containing the PyTorch model blocks to be logged,
with the name of each block as the key and the block as the value.
For example: {"vae": vae, "encoder": encoder, "decoder": decoder, "optimizer": optimizer}.
If an "optimizer" key is present, that entry is also stored as the logger's optimizer.
DataFed_path (str): Path to the DataFed configuration or API.
script_path (str, optional): Path to the script or notebook. Default is None.
local_model_path (str, optional): Local directory to store model files. Default is './'.
log_file_path (str, optional): Local file to store a log of the code evaluation. Default is 'log.txt'.
input_data_shape (tuple, optional): Shape of the input training data for the model. Default is None.
dataset_id_or_path (str, optional): DataFed ID or local path of the input dataset for the model. Default is None.
logging (bool, optional): Flag for logging output. Default is False.
download_kwargs (dict, optional): Keyword arguments passed to DataFed downloads. Default is {"wait": True, "orig_fname": True}.
"""
self.current_checkpoint_id = None
self.notebook_record_id = None
self.__file__ = script_path
self.model_dict = model_dict
if "optimizer" in model_dict.keys():
self.optimizer = self.model_dict["optimizer"]
self.DataFed_path = DataFed_path
self.local_model_path = local_model_path
self.log_file_path = log_file_path
self.dataset_id_or_path = dataset_id_or_path
self.download_kwargs = download_kwargs
self.logging = logging
self.input_data_shape = input_data_shape
make_folder(self.local_model_path)
self.df_api = DataFed(
self.DataFed_path,
self.local_model_path,
log_file_path=self.log_file_path,
dataset_id_or_path=self.dataset_id_or_path,
download_kwargs=self.download_kwargs,
logging=True,
)
# Check if Globus has access to the local path
check_globus_file_access(
self.df_api.endpointDefaultGet(), self.local_model_path
)
# Save the notebook to DataFed
self.save_notebook()
def reset(self):
self.current_checkpoint_id = None
@property
def optimizer(self):
"""
Returns the optimizer used for training.
Returns:
torch.optim.Optimizer: The optimizer instance.
"""
return self._optimizer
@optimizer.setter
def optimizer(self, optimizer):
"""
Sets the optimizer used for training.
Args:
optimizer (torch.optim.Optimizer): The optimizer to be set.
"""
self._optimizer = optimizer
def getModelArchitectureStateDict(self):
"""
Generates a dictionary where each key is a model architecture block name
and the value is the corresponding state dictionary to go in the saved checkpoint,
for example {"vae": vae.state_dict(), "optimizer": optimizer.state_dict()}.
Returns:
dict: A dictionary mapping each block name to its state dictionary.
"""
return {block: module.state_dict() for block, module in self.model_dict.items()}
def getUserClock(self):
"""
Gets the current user name and a timestamp for the current time.
Returns:
tuple: A (current_user, current_time) tuple of strings.
"""
# Get the current user
current_user = getpass.getuser()
# Get the current time
current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
return current_user, current_time
def save_notebook(self):
"""
Saves the Jupyter notebook that runs the code training the model
"""
# Don't upload the notebook to DataFed if it is already there.
# First, make sure a notebook file was given (not None); otherwise there is no
# notebook to upload and we proceed to saving the checkpoints as usual.
# If the notebook "filename" is actually a DataFed ID, the record already exists in DataFed.
if self.__file__ is not None:
# check whether the notebook file name is a DataFed ID
if self.__file__.startswith("d/"):
self.notebook_record_id = self.__file__
# if the notebook filename is not a DataFed ID, check if a notebook of the same name exists at the DataFed file path
else:
try:
# this will fail if no match is found, meaning the notebook does not already exist on DataFed
self.notebook_record_id = (
self.df_api.get_notebook_DataFed_ID_from_path_and_title(
self.__file__
)
)
except Exception:
# the notebook is not already in DataFed, so upload it
# output to user
old_checksum = ""
if self.logging:
with open(self.log_file_path, "a") as f:
timestamp = (
datetime.now()
.astimezone()
.strftime("%Y-%m-%d %H:%M:%S")
)
f.write(
f"\n {timestamp} - Uploading notebook {self.__file__} to DataFed..."
)
# generate a checksum (and script path) for the notebook
self.notebook_metadata = getNotebookMetadata(self.__file__)
# extract the checksum
new_checksum = self.notebook_metadata["script"]["checksum"]
# if the notebook has a DataFed record ID, extract the checksum and compare to the new checksum
if self.notebook_record_id is not None:
notebook_DataFed_metadata = json.loads(
self.df_api.dataView(self.notebook_record_id)[0].data[0].metadata
)
old_checksum = notebook_DataFed_metadata["script"]["checksum"]
if new_checksum != old_checksum:
# do the uploading
if self.logging:
with open(self.log_file_path, "a") as f:
timestamp = (
datetime.now().astimezone().strftime("%Y-%m-%d %H:%M:%S")
)
f.write(
f"\n {timestamp} - Uploading notebook {self.__file__} to DataFed..."
)
current_user, current_time = self.getUserClock()
self.notebook_metadata = self.notebook_metadata | {
"user": current_user,
"timestamp": current_time,
}
# store the dataset DataFed ID in self.dataset_id. Upload the dataset to DataFed if necessary
self.dataset_id = self.df_api.upload_dataset_to_DataFed()
self.notebook_record_resp = self.df_api.data_record_create(
metadata=self.notebook_metadata,
record_title=self.__file__.split("/")[-1], # .split(".")[0],
deps=self.df_api.addDerivedFrom(self.dataset_id),
)
self.df_api.upload_file(
self.notebook_record_resp[0].data[0].id, self.__file__
)
self.notebook_record_id = self.notebook_record_resp[0].data[0].id
def save(
self,
record_file_name,
datafed=True,
local_file_path=None,
local_vars=None,
model_hyperparameters=None,
**kwargs,
):
"""
Saves the model's state dictionary locally (unless the file already exists)
and optionally uploads it to DataFed along with the model's metadata.
If you want to upload multiple files to the same DataFed data record, you can zip them
together and pass the local path to the zip file as "local_file_path".
Args:
record_file_name (str): The name of the file to save the model locally.
datafed (bool, optional): If True, the record is uploaded to DataFed. Default is True.
local_file_path (str or pathlib.Path, optional): The local file path at which to save the weights,
or the path to a presaved file to upload to DataFed.
local_vars (list): A list containing the local variables of the model training code, from list(locals().items()). Used to determine the metadata.
model_hyperparameters (dict): A dictionary where the keys are the model hyperparameter names and the values are the hyperparameter values. Included in the saved checkpoint.
**kwargs: Additional metadata or attributes to include in the record.
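Example:
A hypothetical call from inside a training loop (the epoch counter, loss value,
and checkpoint path below are placeholders):

    logger.save(
        f"model_epoch_{epoch}.pth",
        datafed=True,
        local_file_path=f"./checkpoints/model_epoch_{epoch}.pth",
        local_vars=list(locals().items()),
        model_hyperparameters={"epochs": epoch, "loss": float(loss)},
    )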
"""
# include the model architecture state dictionaries and model hyperparameters in the
# checkpoint, then save it locally unless a presaved file (or zip archive) was provided
if not str(local_file_path).endswith(".zip") and not os.path.exists(
str(local_file_path)
):
checkpoint = self.getModelArchitectureStateDict() | (model_hyperparameters or {})
# Save the model state dict locally
torch.save(checkpoint, local_file_path)
if datafed:
# Safely retrieve values and replace with None if undefined or not present
notebook_record_id = (
self.notebook_record_id
if self.notebook_record_id and len(self.notebook_record_id) > 0
else None
)
# Saves the record id to the object
self.notebook_record_id = notebook_record_id
current_checkpoint_id = self.current_checkpoint_id
self.dataset_id = self.df_api.upload_dataset_to_DataFed()
# Create a list of dependency IDs, excluding any that are None
ids_to_add = [
rec_id
for rec_id in [notebook_record_id, current_checkpoint_id]
if rec_id is not None
]
if isinstance(self.dataset_id, str):
ids_to_add.append(self.dataset_id)
elif isinstance(self.dataset_id, list):
ids_to_add.extend(rec_id for rec_id in self.dataset_id if rec_id is not None)
# if self.dataset_id is None, there are no dataset dependencies to add
# Call the API method with the valid IDs (if any)
if ids_to_add:
deps = self.df_api.addDerivedFrom(ids_to_add)
else:
deps = None # If no valid IDs are present, set deps to None
# if isinstance(deps, list):
# deps = functools.reduce(operator.iconcat, deps, []) #[item for sublist in deps for item in sublist]
# Generate metadata and create a data record in DataFed
metadata = self.getMetadata(
local_vars=local_vars,
model_hyperparameters=model_hyperparameters,
**kwargs,
)
dc_resp = self.df_api.data_record_create(
metadata,
record_title=str(record_file_name),
local_model_path=self.local_model_path,
# weights_file_path = weights_file_path,
# embedding_file_path = embedding_file_path,
# reconstruction_file_path = reconstruction_file_path,
deps=deps,
)
# Upload the saved model to DataFed
self.df_api.upload_file(dc_resp[0].data[0].id, str(local_file_path))
self.current_checkpoint_id = dc_resp[0].data[0].id
class InferenceEvaluation:
"""
Runs model inference over a dataframe of DataFed checkpoint records, downloading
checkpoint files from DataFed when they cannot be found locally. Child classes must
implement `build_model` and `evaluate`.
"""
def __init__(
self,
dataframe,
dataset,
df_api,
root_directory=None,
save_directory="./tmp/",
skip=None,
**kwargs,
):
self.df = dataframe
self.dataset_id = dataset
self.root_directory = root_directory
self.save_directory = save_directory
self.df_api = df_api
self.skip = skip
self.model = self.build_model(**kwargs)
# Create a logger
self.logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.WARNING)
def file_not_found(self, filename, row):
"""
Attempts to download a file from DataFed when it cannot be found locally.
Args:
filename (str): The name of the file to download.
row (pd.Series): The dataframe row containing the DataFed record id.
Returns:
list or None: The local path(s) to the downloaded file, or None if the download failed.
"""
self.logger.warning(
f"{filename} was not found locally; downloading from DataFed using record id {row.id}"
)
print(
f"Attempting to download {filename} from DataFed using record id {row.id}"
)
ds_rep = self.df_api.dataGet(row.id, self.save_directory, wait=True)
if ds_rep[0].task[0].status == 3:
# Download was successful
self.logger.info(f"{filename} was downloaded successfully")
# find the downloaded file recursively
file_path = find_files_recursive(self.root_directory, filename)
return file_path
else:
# Download was not successful
self.logger.error(f"{filename} could not be downloaded")
# return None so the caller can skip inference for this record
return None
def _getFileName(self, row):
return self.df_api.getFileName(row.id)
@staticmethod
def get_first_entry_if_list(data):
if isinstance(data, list) and len(data) > 0:
return data[0] # Return the first entry if it's a non-empty list
else:
return data
def run_inference(self, row):
# retrieve the filename from the DataFed data record
filename = self._getFileName(row)
# checks if the file can be found in the root directory
file_path = find_files_recursive(self.root_directory, filename)
if len(file_path) == 0:
# if the file is not found, attempt to download it from DataFed
file_path = self.file_not_found(filename, row)
file_path = self.get_first_entry_if_list(file_path)
if file_path is None:
self.logger.info(
f"{filename} could not be downloaded, skipping inference."
)
print(f"{filename} could not be downloaded, skipping inference.")
return None
# load the model
self.model.load(file_path[0])
return self.evaluate(row, file_path)
def build_model(self, **kwargs):
"""
Builds and returns the model to be used for inference.
This method should be implemented by the child class to define the specific model architecture
and any necessary configurations.
Returns:
torch.nn.Module: The model object to be used for inference.
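Example:
A minimal, hypothetical subclass sketch (MyAutoencoder and its latent_dim argument
are placeholders for your own architecture):

    class MyInference(InferenceEvaluation):
        def build_model(self, latent_dim=16, **kwargs):
            # the returned model must provide the .load(file_path) method used by run_inference
            return MyAutoencoder(latent_dim=latent_dim)

        def evaluate(self, row, file_path):
            return {"checkpoint_file": str(file_path)}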
"""
raise NotImplementedError(
"Child class must implement this method. This method should return a model object."
)
def evaluate(self, row, file_path):
"""
Evaluates the model on the given data. This method should be implemented by the child class.
The parent class does not implement this method.
Args:
row (pd.Series): A row from the dataframe containing metadata and other information.
file_path (str): The path to the file to be used for evaluation.
Returns:
dict: The evaluation results as a dictionary.
"""
raise NotImplementedError(
"Child class must implement this method. This method should return evaluation results as a dictionary."
)
def run(self):
for i, row in tqdm(self.df.iterrows(), total=self.df.shape[0]):
# allows restarting a partially completed run by skipping rows up to and including self.skip
if self.skip is not None and i <= self.skip:
continue
# runs the inference
msg = self.run_inference(row)
# if file cannot be found, skip inference
if msg is None:
continue
# updates the metadata of the record
self.df_api.dataUpdate(row.id, metadata=json.dumps(msg))
# logs the success of the inference
self.logger.info(f"Inference for {i} record {row.id} was successful")
class TorchViewer(nn.Module):
"""
Provides a viewer for retrieving model checkpoint metadata stored in DataFed.
"""
def __init__(self, DataFed_path, **kwargs):
super().__init__()
self.DataFed_path = DataFed_path
self.df_api = DataFed(self.DataFed_path, **kwargs)
def getModelCheckpoints(
self,
exclude_metadata="computing",
excluded_keys="script",
non_unique=["id", "timestamp", "total_time"],
format="pandas",
):
"""
Retrieves the metadata records for the model checkpoints stored in the DataFed collection.
Args:
exclude_metadata (str, list, or None, optional): Metadata fields to exclude from the extracted records.
excluded_keys (str, list, or None, optional): Keys whose presence in a metadata record excludes that record.
non_unique (str, list, or None, optional): Keys expected to differ between records regardless of record uniqueness - these are not considered when finding unique records.
format (str, optional): The format to return the metadata in. Defaults to "pandas".
Returns:
The metadata records for the model checkpoints, in the requested format.
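Example:
A hypothetical usage sketch (the DataFed collection path is a placeholder):

    viewer = TorchViewer("ProjectCollection/vae_runs")
    checkpoints = viewer.getModelCheckpoints(format="pandas")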
"""
return self.df_api.get_metadata(
exclude_metadata=exclude_metadata,
excluded_keys=excluded_keys,
non_unique=non_unique,
format=format,
)