import torch
import numpy as np
import time, os
import tifffile as tif
from datetime import datetime
from zipfile import ZipFile
from pytz import timezone
from train_tools.data_utils.transforms import get_pred_transforms
class BasePredictor:
    """Base class for running a segmentation model over a folder of images.

    Subclasses implement ``_inference`` (forward pass) and ``_post_process``
    (turn the raw network output into a label mask). ``conduct_prediction``
    then loops over every file in ``input_path`` and writes one
    ``<stem>_label.tiff`` per image into ``output_path``. Optionally, it zips
    the results into ``./submissions/<exp_name>.zip``.
    """

    def __init__(
        self,
        model,
        device,
        input_path,
        output_path,
        make_submission=False,
        exp_name=None,
        algo_params=None,
    ):
        """Store the run configuration and prepare the inference environment.

        Args:
            model: Network. ``conduct_prediction`` calls ``.to(device)`` and
                ``.eval()`` on it before predicting.
            device: Target passed to ``model.to`` and ``tensor.to``.
            input_path: Directory whose files (sorted by name) are predicted.
            output_path: Directory receiving the predicted masks. It is
                created if missing.
            make_submission: If True, warn on low cell counts and zip the
                outputs after prediction.
            exp_name: Optional experiment-name prefix. A ``%m%d_%H%M``
                timestamp is always appended by ``_setups``.
            algo_params: Optional mapping whose items become instance
                attributes (algorithm-specific settings for subclasses).
        """
        # model/device/input_path must be stored: conduct_prediction() reads
        # self.model / self.device and _setups() reads self.input_path.
        self.model = model
        self.device = device
        self.input_path = input_path
        self.output_path = output_path
        self.make_submission = make_submission
        self.exp_name = exp_name

        # Assign algorithm-specific arguments. They are applied last, so they
        # can override the attributes set above.
        if algo_params:
            self.__dict__.update(algo_params.items())

        # Prepare inference environments
        self._setups()

    @torch.no_grad()
    def conduct_prediction(self):
        """Predict every image in ``input_path`` and write the label masks.

        The per-image timing covers inference, post-processing and writing
        the mask. It does not cover loading/transforming the input. If
        ``make_submission`` is set, the output directory is zipped afterwards.

        Returns:
            float: Seconds spent on the last image processed, or 0.0 when
            there are no input images.
        """
        self.model.to(self.device)
        self.model.eval()

        total_time = 0.0
        # Initialised up front so an empty input directory does not raise
        # UnboundLocalError at the return statement.
        time_cost = 0.0

        for img_name in self.img_names:
            img_data = self._get_img_data(img_name).to(self.device)

            start = time.time()
            pred_mask = self._inference(img_data)
            pred_mask = self._post_process(pred_mask.squeeze(0).cpu().numpy())
            self.write_pred_mask(
                pred_mask, self.output_path, img_name, self.make_submission
            )
            time_cost = time.time() - start
            total_time += time_cost

            print(
                f"Prediction finished: {img_name}; img size = {img_data.shape}; costing: {time_cost:.2f}s"
            )

        print(f"\n Total Time Cost: {total_time:.2f}s")

        if self.make_submission:
            submission_path = self._write_submission_zip()
            print("\n>>>>> Submission file is saved at: %s\n" % submission_path)

        # NOTE(review): this returns the last image's cost, not total_time,
        # despite the "Total Time Cost" print above. Confirm what callers expect.
        return time_cost

    def _write_submission_zip(self, submission_dir="./submissions"):
        """Zip every file in ``output_path`` into ``<submission_dir>/<exp_name>.zip``.

        Returns:
            str: Path of the written archive.
        """
        os.makedirs(submission_dir, exist_ok=True)
        submission_path = os.path.join(submission_dir, f"{self.exp_name}.zip")
        with ZipFile(submission_path, "w") as zip_file:
            for pred_name in sorted(os.listdir(self.output_path)):
                # No arcname is given, so the archive entry is the file path
                # itself (ZipFile strips any drive and leading separators).
                zip_file.write(os.path.join(self.output_path, pred_name))
        return submission_path

    def write_pred_mask(self, pred_mask, output_dir, image_name, submission=False):
        """Save ``pred_mask`` as a zlib-compressed ``<output_dir>/<stem>_label.tiff``.

        When ``submission`` is True, a warning is printed if the mask's
        largest label is below 5. This assumes labels run 1..N, so that the
        max label equals the cell count. TODO: confirm against the
        ``_post_process`` implementations.
        """
        # All images should contain at least 5 cells; only fewer than 5 is
        # suspicious (the old `not max > 5` also warned on exactly 5).
        if submission:
            n_cells = np.max(pred_mask)
            if n_cells < 5:
                print("[!Caution] Only %d Cells Detected!!!\n" % n_cells)

        file_path = os.path.join(output_dir, self._pred_mask_file_name(image_name))
        tif.imwrite(file_path, pred_mask, compression="zlib")

    @staticmethod
    def _pred_mask_file_name(image_name):
        """Map an input file name to its mask name, e.g. ``x.tif`` -> ``x_label.tiff``."""
        # splitext removes only the final extension. Previously "a.1.tif" and
        # "a.2.tif" both became "a_label.tiff" and overwrote each other.
        return os.path.splitext(image_name)[0] + "_label.tiff"

    def _setups(self):
        """Build the input transform, create ``output_path``, timestamp
        ``exp_name`` and list the input images (sorted by name)."""
        self.pred_transforms = get_pred_transforms()
        os.makedirs(self.output_path, exist_ok=True)

        # Timestamp in Asia/Seoul local time, formatted MMDD_HHMM.
        dt_string = datetime.now(timezone("Asia/Seoul")).strftime("%m%d_%H%M")
        self.exp_name = (
            self.exp_name + dt_string if self.exp_name is not None else dt_string
        )

        self.img_names = sorted(os.listdir(self.input_path))

    def _get_img_data(self, img_name):
        """Load ``img_name`` from ``input_path`` through ``pred_transforms``
        and prepend a batch dimension of size 1."""
        img_path = os.path.join(self.input_path, img_name)
        return self.pred_transforms(img_path).unsqueeze(0)

    def _inference(self, img_data):
        """Run the model on the batched input tensor. Implemented by subclasses."""
        raise NotImplementedError

    def _post_process(self, pred_mask):
        """Turn the model output (batch dim squeezed, as a NumPy array) into
        a label mask. Implemented by subclasses."""
        raise NotImplementedError