Showing with 273 additions and 0 deletions.
- +33 −0 src/main/python/tf/detect.py
- +17 −0 src/main/python/tf/extract_images.py
- 0 src/main/python/tf/model/__init__.py
- +109 −0 src/main/python/tf/model/object_detection.py
- +61 −0 src/main/python/tf/model/preprocess.py
- 0 src/main/python/tf/util/__init__.py
- +46 −0 src/main/python/tf/util/init.py
- +7 −0 src/main/python/tf/util/spark.conf
@@ -0,0 +1,33 @@
import os
import sys
from util.init import *
from model.object_detection import *
PYAUT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(PYAUT_DIR)

from aut.common import WebArchive
from pyspark.sql import DataFrame


if __name__ == "__main__":
    # initialization
    args = get_args()
    sys.path.append(args.spark)
    conf, sc, sql_context = init_spark(args.master, args.aut_jar)
    zip_model_module(PYAUT_DIR)
    sc.addPyFile(os.path.join(PYAUT_DIR, "tf", "model.zip"))
    if args.img_model == "ssd":
        detector = SSD(sc, sql_context, args)

    # preprocessing raw images
    arc = WebArchive(sc, sql_context, args.web_archive)
    df = DataFrame(arc.loader.extractImages(arc.path), sql_context)
    filter_size = tuple(args.filter_size)
    print("height >= %d and width >= %d" % filter_size)
    preprocessed = df.filter("height >= %d and width >= %d" % filter_size)

    # detection
    model_broadcast = detector.broadcast()
    detect_udf = detector.get_detect_udf(model_broadcast)
    res = preprocessed.select("url", detect_udf(col("bytes")).alias("prediction"), "bytes")
    res.write.json(args.output_path)
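For orientation, a hedged sketch (not part of this commit) of what one line of the JSON output written by res.write.json is expected to contain: the image URL, a prediction holding a list of confidence scores and a matching list of MS COCO class ids, and the base64 image bytes. The URL and numbers below are made-up placeholders.

import json

# Illustrative record only; values are placeholders, not real model output.
sample_line = json.dumps({
    "url": "http://example.org/picture.jpg",       # image URL extracted from the WARC
    "prediction": [[0.91, 0.42], [1.0, 18.0]],     # [scores, class ids] for detected boxes
    "bytes": "<base64-encoded image bytes>",       # same payload that was scored
})

rec = json.loads(sample_line)
scores, classes = rec["prediction"]
print(list(zip(classes, scores)))                  # [(1.0, 0.91), (18.0, 0.42)]
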
@@ -0,0 +1,17 @@
import numpy as np
import argparse
from model.object_detection import SSDExtractor


def get_args():
    parser = argparse.ArgumentParser(description='Extracting images from model output.')
    parser.add_argument('--res_dir', help='Path of result (model output) directory.')
    parser.add_argument('--output_dir', help='Path of extracted image file output directory.')
    parser.add_argument('--threshold', type=float, help='Threshold of detection confidence scores.')
    return parser.parse_args()


if __name__ == "__main__":
    args = get_args()
    extractor = SSDExtractor(args.res_dir, args.output_dir)
    extractor.extract_and_save(class_ids="all", threshold=args.threshold)
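As a usage sketch (not from this commit), SSDExtractor can also be called directly for a subset of class ids rather than "all"; the directory names below are placeholders, and ids 1 and 18 are assumed to map to "person" and "dog" in mscoco_label_map.pbtxt.

from model.object_detection import SSDExtractor

# res_dir/output_dir are example paths; point res_dir at the detect.py output.
extractor = SSDExtractor(res_dir="warc_res", output_dir="extracted_images")
extractor.extract_and_save(class_ids=[1, 18], threshold=0.5)
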
No changes (src/main/python/tf/model/__init__.py is an empty package marker).
@@ -0,0 +1,109 @@
import pickle
import os
import json
import numpy as np
from .preprocess import *
from pyspark.sql.functions import pandas_udf, PandasUDFType, col
from pyspark.sql.types import ArrayType, FloatType
import tensorflow as tf
import pandas as pd


PKG_DIR = os.path.dirname(os.path.abspath(__file__))


class ImageExtractor:
    def __init__(self, res_dir, output_dir):
        self.res_dir = res_dir
        self.output_dir = output_dir


    def _extract_and_save(self, rec, class_ids, threshold):
        raise NotImplementedError("Please override this method.")


    def extract_and_save(self, class_ids, threshold):
        if class_ids == "all":
            class_ids = list(self.cate_dict.keys())

        for idx in class_ids:
            cls = self.cate_dict[idx]
            check_dir(self.output_dir + "/%s/" % cls, create=True)

        for fname in os.listdir(self.res_dir):
            if fname.startswith("part-"):
                print("Extracting:", self.res_dir + "/" + fname)
                with open(self.res_dir + "/" + fname) as f:
                    for line in f:
                        rec = json.loads(line)
                        self._extract_and_save(rec, class_ids, threshold)


class SSD:
    def __init__(self, sc, sql_context, args):
        self.sc = sc
        self.sql_context = sql_context
        self.category = load_cate_dict_from_pbtxt("%s/category/mscoco_label_map.pbtxt" % PKG_DIR)
        self.checkpoint = "%s/graph/ssd_mobilenet_v1_fpn_640x640/frozen_inference_graph.pb" % PKG_DIR
        self.args = args
        with tf.io.gfile.GFile(self.checkpoint, 'rb') as f:
            model_params = f.read()
        self.model_params = model_params


    def broadcast(self):
        return self.sc.broadcast(self.model_params)


    def get_detect_udf(self, model_broadcast):
        def batch_proc(bytes_batch):
            with tf.Graph().as_default() as g:
                graph_def = tf.GraphDef()
                graph_def.ParseFromString(model_broadcast.value)
                tf.import_graph_def(graph_def, name='')
                image_tensor = g.get_tensor_by_name('image_tensor:0')
                detection_scores = g.get_tensor_by_name('detection_scores:0')
                detection_classes = g.get_tensor_by_name('detection_classes:0')

                with tf.Session().as_default() as sess:
                    result = []
                    image_size = (640, 640)
                    images = np.array([img2np(b, image_size) for b in bytes_batch])
                    res = sess.run([detection_scores, detection_classes], feed_dict={image_tensor: images})
                    for i in range(res[0].shape[0]):
                        result.append([res[0][i], res[1][i]])
            return pd.Series(result)
        return pandas_udf(ArrayType(ArrayType(FloatType())), PandasUDFType.SCALAR)(batch_proc)


class SSDExtractor(ImageExtractor):
    def __init__(self, res_dir, output_dir):
        super().__init__(res_dir, output_dir)
        self.cate_dict = load_cate_dict_from_pbtxt("%s/category/mscoco_label_map.pbtxt" % PKG_DIR)


    def _extract_and_save(self, rec, class_ids, threshold):
        pred = rec['prediction']
        scores = np.array(pred[0])
        classes = np.array(pred[1])
        valid_classes = np.unique(classes[scores >= threshold])
        if valid_classes.shape[0] > 0:
            if class_ids != "all":
                inter = list(set(valid_classes).intersection(set(class_ids)))
                if len(inter) > 0:
                    valid_classes = np.array(inter)
                else:
                    valid_classes = None
        else:
            valid_classes = None

        if valid_classes is not None:
            for cls_idx in valid_classes:
                cls = self.cate_dict[cls_idx]
                try:
                    img = str2img(rec["bytes"])
                    img.save(self.output_dir + "/%s/" % cls + url_parse(rec["url"]))
                except Exception:
                    fname = self.output_dir + "/%s/" % cls + url_parse(rec["url"])
                    print("Failed to save:", fname)

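A minimal, TensorFlow-free sketch of the pandas scalar UDF contract that get_detect_udf builds on: the wrapped function receives a pandas Series holding one Arrow batch of column values and must return a Series of the same length. Everything below (local master, dummy DataFrame, constant predictions) is illustrative only and not part of this commit.

import pandas as pd
from pyspark.sql import SparkSession
from pyspark.sql.functions import pandas_udf, PandasUDFType, col
from pyspark.sql.types import ArrayType, FloatType

spark = SparkSession.builder.master("local[1]").getOrCreate()

def batch_proc(bytes_batch):
    # Stand-in for sess.run: emit a constant [score, class] pair per input row.
    return pd.Series([[0.0, 0.0] for _ in bytes_batch])

detect_udf = pandas_udf(ArrayType(FloatType()), PandasUDFType.SCALAR)(batch_proc)
df = spark.createDataFrame([("http://example.org/a.jpg", "bytes-go-here")], ["url", "bytes"])
df.select("url", detect_udf(col("bytes")).alias("prediction")).show(truncate=False)
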
@@ -0,0 +1,61 @@
from PIL import Image
import io
import base64
import os
import numpy as np
import re


def str2img(byte_str):
    return Image.open(io.BytesIO(base64.b64decode(bytes(byte_str, 'utf-8'))))


def img2np(byte_str, resize=None):
    try:
        image = str2img(byte_str)
        img = image.convert("RGB")
        if resize is not None:
            img = img.resize(resize, Image.BILINEAR)
        img = np.array(img).astype(np.uint8)
        img_shape = np.shape(img)

        if len(img_shape) == 2:
            img = np.stack([img, img, img], axis=-1)
        elif img_shape[-1] >= 3:
            img = img[:, :, :3]

        return img

    except Exception:
        if resize is not None:
            return np.zeros((resize[0], resize[1], 3), dtype=np.uint8)
        else:
            return np.zeros((1, 1, 3), dtype=np.uint8)


def url_parse(url):
    return url.split("://")[1].replace("/", "%%%%")


def check_dir(path, create=False):
    if os.path.exists(path):
        return True
    else:
        if create:
            os.makedirs(path, exist_ok=True)
        return False


def load_cate_dict_from_pbtxt(path, key="id", value="display_name"):
    cate_dict = {}
    with open(path) as f:
        for line in f:
            entry = line.strip().split(":")
            if len(entry) > 1:
                if entry[0] == key:
                    cur_key = int(entry[1])
                if entry[0] == value:
                    cur_cate = re.findall(r'"(.*?)"', entry[1])[0]
                    cate_dict[cur_key] = cur_cate
    return cate_dict

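A quick, self-contained check (not part of this commit) of the label-map parsing in load_cate_dict_from_pbtxt, using a throwaway two-entry snippet in the same item format as mscoco_label_map.pbtxt; the temp file, ids, and name strings are illustrative only.

import os
import tempfile
from model.preprocess import load_cate_dict_from_pbtxt  # assumes model/ is importable

snippet = '''item {
  name: "/m/01g317"
  id: 1
  display_name: "person"
}
item {
  name: "/m/0199g"
  id: 2
  display_name: "bicycle"
}
'''

with tempfile.NamedTemporaryFile("w", suffix=".pbtxt", delete=False) as tmp:
    tmp.write(snippet)

print(load_cate_dict_from_pbtxt(tmp.name))   # expected: {1: 'person', 2: 'bicycle'}
os.unlink(tmp.name)
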
No changes (src/main/python/tf/util/__init__.py is an empty package marker).
@@ -0,0 +1,46 @@
import argparse
import os
import zipfile
from pyspark import SparkConf, SparkContext, SQLContext
import re

def init_spark(master, aut_jar):
    conf = SparkConf()
    conf.set("spark.jars", aut_jar)
    conf_path = os.path.dirname(os.path.abspath(__file__)) + "/spark.conf"
    conf_dict = read_conf(conf_path)
    for item, value in conf_dict.items():
        conf.set(item, value)
    sc = SparkContext(master, "aut image analysis", conf=conf)
    sql_context = SQLContext(sc)
    return conf, sc, sql_context


def get_args():
    parser = argparse.ArgumentParser(description='PySpark for Web Archive Image Retrieval.')
    parser.add_argument('--web_archive', help='Path to warcs.', default='/tuna1/scratch/nruest/geocites/warcs')
    parser.add_argument('--aut_jar', help='Path to compiled aut jar.', default='aut/target/aut-0.17.1-SNAPSHOT-fatjar.jar')
    parser.add_argument('--spark', help='Path to Apache Spark.', default='spark-2.3.2-bin-hadoop2.7/bin')
    parser.add_argument('--master', help='Apache Spark master IP address and port.', default='spark://127.0.1.1:7077')
    parser.add_argument('--img_model', help='Model for image processing.', default='ssd')
    parser.add_argument('--filter_size', nargs='+', type=int, help='Filter out images smaller than filter_size.', default=[640, 640])
    parser.add_argument('--output_path', help='Path to image model output.', default='warc_res')
    return parser.parse_args()


def zip_model_module(PYAUT_DIR):
    with zipfile.ZipFile(os.path.join(PYAUT_DIR, "tf", "model.zip"), "w") as zipf:
        zipf.write(os.path.join(PYAUT_DIR, "tf", "model", "__init__.py"), os.path.join("model", "__init__.py"))
        zipf.write(os.path.join(PYAUT_DIR, "tf", "model", "object_detection.py"), os.path.join("model", "object_detection.py"))
        zipf.write(os.path.join(PYAUT_DIR, "tf", "model", "preprocess.py"), os.path.join("model", "preprocess.py"))


def read_conf(conf_path):
    conf_dict = {}
    with open(conf_path) as f:
        for line in f:
            conf = re.findall(r'\S+', line.strip())
            conf_dict[conf[0]] = conf[1]
    return conf_dict

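A small sketch (not from this commit) of the whitespace-separated "key value" format that read_conf expects, using a temp file instead of the real util/spark.conf.

import os
import tempfile
from util.init import read_conf  # assumes util/ is importable

with tempfile.NamedTemporaryFile("w", suffix=".conf", delete=False) as tmp:
    tmp.write("spark.executor.memory 16G\nspark.task.cpus 6\n")

print(read_conf(tmp.name))   # expected: {'spark.executor.memory': '16G', 'spark.task.cpus': '6'}
os.unlink(tmp.name)
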
@@ -0,0 +1,7 @@
spark.sql.execution.arrow.enabled true
spark.sql.execution.arrow.maxRecordsPerBatch 320
spark.executor.memory 16G
spark.cores.max 48
spark.executor.cores 6
spark.driver.memory 64G
spark.task.cpus 6
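For reference, a sketch of the programmatic equivalent of this file as init_spark applies it (read_conf followed by SparkConf.set); it assumes nothing beyond the settings above. Note that spark.sql.execution.arrow.maxRecordsPerBatch caps how many image rows each call to the pandas UDF's batch_proc receives, so it bounds per-batch memory on the executors.

from pyspark import SparkConf

conf = SparkConf()
conf.set("spark.sql.execution.arrow.enabled", "true")
conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", "320")  # rows per pandas UDF batch
conf.set("spark.executor.memory", "16G")
conf.set("spark.cores.max", "48")
conf.set("spark.executor.cores", "6")
conf.set("spark.driver.memory", "64G")
conf.set("spark.task.cpus", "6")
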