Add image analysis and extraction w/TensorFlow (#318)

h324yang authored and ruebot committed Jul 5, 2019
1 parent 5cb05f7 commit 7a61f0e2201fd9e316af0b257dc3dcfd76ea7e25
@@ -0,0 +1,33 @@
import os
import sys

# Make the repository root importable before loading the util and model packages.
PYAUT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(PYAUT_DIR)

from util.init import *
from model.object_detection import *
from aut.common import WebArchive
from pyspark.sql import DataFrame
from pyspark.sql.functions import col


if __name__ == "__main__":
    # Initialization: parse arguments, start Spark, and ship the model package to the executors.
    args = get_args()
    sys.path.append(args.spark)
    conf, sc, sql_context = init_spark(args.master, args.aut_jar)
    zip_model_module(PYAUT_DIR)
    sc.addPyFile(os.path.join(PYAUT_DIR, "tf", "model.zip"))
    if args.img_model == "ssd":
        detector = SSD(sc, sql_context, args)

    # Preprocess raw images: load the WARCs, extract images, and keep only those at least filter_size.
    arc = WebArchive(sc, sql_context, args.web_archive)
    df = DataFrame(arc.loader.extractImages(arc.path), sql_context)
    filter_size = tuple(args.filter_size)
    print("height >= %d and width >= %d" % filter_size)
    preprocessed = df.filter("height >= %d and width >= %d" % filter_size)

    # Detection: broadcast the frozen graph and score each image with the pandas UDF.
    model_broadcast = detector.broadcast()
    detect_udf = detector.get_detect_udf(model_broadcast)
    res = preprocessed.select("url", detect_udf(col("bytes")).alias("prediction"), "bytes")
    res.write.json(args.output_path)
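
Note: res.write.json writes one JSON object per line, holding the selected url, prediction, and bytes columns; the extractor script below consumes these part-* files. A rough sketch of a single record (all values are illustrative, not taken from a real run):

record = {
    "url": "http://example.org/images/cat.jpg",   # image URL from the WARC
    "prediction": [
        [0.91, 0.43, 0.05],                       # detection_scores, one per candidate box
        [17.0, 17.0, 1.0],                        # detection_classes, COCO ids aligned with the scores
    ],
    "bytes": "<base64-encoded image data>",       # the same string that str2img/img2np decode
}
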
@@ -0,0 +1,17 @@
import argparse

from model.object_detection import SSDExtractor


def get_args():
    parser = argparse.ArgumentParser(description='Extract images from model output.')
    parser.add_argument('--res_dir', help='Path of the result (model output) directory.')
    parser.add_argument('--output_dir', help='Path of the extracted image output directory.')
    parser.add_argument('--threshold', type=float, help='Threshold on detection confidence scores.')
    return parser.parse_args()


if __name__ == "__main__":
    args = get_args()
    extractor = SSDExtractor(args.res_dir, args.output_dir)
    extractor.extract_and_save(class_ids="all", threshold=args.threshold)
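
The same extraction can also be driven from Python rather than the command line; a minimal sketch (directory names are placeholders, and the class ids are just examples from the COCO label map):

from model.object_detection import SSDExtractor

# "warc_res" is the default --output_path of the detection job; "extracted_images" is a placeholder.
extractor = SSDExtractor("warc_res", "extracted_images")
# Restrict extraction to specific COCO classes instead of "all", e.g. 1 = person, 17 = cat.
extractor.extract_and_save(class_ids=[1, 17], threshold=0.5)
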
@@ -0,0 +1,109 @@
import os
import json
import numpy as np
from .preprocess import *
from pyspark.sql.functions import pandas_udf, PandasUDFType, col
from pyspark.sql.types import ArrayType, FloatType
import tensorflow as tf
import pandas as pd


PKG_DIR = os.path.dirname(os.path.abspath(__file__))


class ImageExtractor:
    def __init__(self, res_dir, output_dir):
        self.res_dir = res_dir
        self.output_dir = output_dir

    def _extract_and_save(self, rec, class_ids, threshold):
        raise NotImplementedError("Please override this method.")

    def extract_and_save(self, class_ids, threshold):
        if class_ids == "all":
            class_ids = list(self.cate_dict.keys())

        # Create one output sub-directory per requested class.
        for idx in class_ids:
            cls = self.cate_dict[idx]
            check_dir(self.output_dir + "/%s/" % cls, create=True)

        # Walk the Spark JSON output (part-* files), one record per line.
        for fname in os.listdir(self.res_dir):
            if fname.startswith("part-"):
                print("Extracting:", self.res_dir + "/" + fname)
                with open(self.res_dir + "/" + fname) as f:
                    for line in f:
                        rec = json.loads(line)
                        self._extract_and_save(rec, class_ids, threshold)


class SSD:
    def __init__(self, sc, sql_context, args):
        self.sc = sc
        self.sql_context = sql_context
        self.category = load_cate_dict_from_pbtxt("%s/category/mscoco_label_map.pbtxt" % PKG_DIR)
        self.checkpoint = "%s/graph/ssd_mobilenet_v1_fpn_640x640/frozen_inference_graph.pb" % PKG_DIR
        self.args = args
        # Read the frozen inference graph once on the driver so it can be broadcast.
        with tf.io.gfile.GFile(self.checkpoint, 'rb') as f:
            model_params = f.read()
        self.model_params = model_params

    def broadcast(self):
        return self.sc.broadcast(self.model_params)

    def get_detect_udf(self, model_broadcast):
        def batch_proc(bytes_batch):
            # Rebuild the frozen graph from the broadcast bytes on each executor.
            with tf.Graph().as_default() as g:
                graph_def = tf.GraphDef()
                graph_def.ParseFromString(model_broadcast.value)
                tf.import_graph_def(graph_def, name='')
                image_tensor = g.get_tensor_by_name('image_tensor:0')
                detection_scores = g.get_tensor_by_name('detection_scores:0')
                detection_classes = g.get_tensor_by_name('detection_classes:0')

                with tf.Session().as_default() as sess:
                    result = []
                    image_size = (640, 640)
                    images = np.array([img2np(b, image_size) for b in bytes_batch])
                    res = sess.run([detection_scores, detection_classes], feed_dict={image_tensor: images})
                    for i in range(res[0].shape[0]):
                        result.append([res[0][i], res[1][i]])
            return pd.Series(result)
        return pandas_udf(ArrayType(ArrayType(FloatType())), PandasUDFType.SCALAR)(batch_proc)


class SSDExtractor(ImageExtractor):
    def __init__(self, res_dir, output_dir):
        super().__init__(res_dir, output_dir)
        self.cate_dict = load_cate_dict_from_pbtxt("%s/category/mscoco_label_map.pbtxt" % PKG_DIR)

    def _extract_and_save(self, rec, class_ids, threshold):
        # prediction[0] holds confidence scores, prediction[1] the matching class ids.
        pred = rec['prediction']
        scores = np.array(pred[0])
        classes = np.array(pred[1])
        valid_classes = np.unique(classes[scores >= threshold])
        if valid_classes.shape[0] > 0:
            if class_ids != "all":
                inter = list(set(valid_classes).intersection(set(class_ids)))
                if len(inter) > 0:
                    valid_classes = np.array(inter)
                else:
                    valid_classes = None
        else:
            valid_classes = None

        if valid_classes is not None:
            for cls_idx in valid_classes:
                cls = self.cate_dict[cls_idx]
                try:
                    img = str2img(rec["bytes"])
                    img.save(self.output_dir + "/%s/" % cls + url_parse(rec["url"]))
                except Exception:
                    fname = self.output_dir + "/%s/" % cls + url_parse(rec["url"])
                    print("Failed to save:", fname)

@@ -0,0 +1,61 @@
from PIL import Image
import io
import base64
import os
import numpy as np
import re


def str2img(byte_str):
    # Image bytes arrive as a base64-encoded string in the "bytes" column.
    return Image.open(io.BytesIO(base64.b64decode(bytes(byte_str, 'utf-8'))))


def img2np(byte_str, resize=None):
    try:
        image = str2img(byte_str)
        img = image.convert("RGB")
        if resize is not None:
            img = img.resize(resize, Image.BILINEAR)
        img = np.array(img).astype(np.uint8)
        img_shape = np.shape(img)

        if len(img_shape) == 2:
            # Grayscale image: replicate to three channels.
            img = np.stack([img, img, img], axis=-1)
        elif img_shape[-1] >= 3:
            # Drop any alpha channel.
            img = img[:, :, :3]

        return img

    except Exception:
        # Undecodable image: return a black placeholder so the batch keeps a consistent shape.
        if resize is not None:
            return np.zeros((resize[0], resize[1], 3), dtype=np.uint8)
        else:
            return np.zeros((1, 1, 3), dtype=np.uint8)


def url_parse(url):
    return url.split("://")[1].replace("/", "%%%%")


def check_dir(path, create=False):
    if os.path.exists(path):
        return True
    else:
        if create:
            os.makedirs(path, exist_ok=True)
        return False


def load_cate_dict_from_pbtxt(path, key="id", value="display_name"):
    cate_dict = {}
    with open(path) as f:
        for line in f:
            entry = line.strip().split(":")
            if len(entry) > 1:
                if entry[0] == key:
                    cur_key = int(entry[1])
                if entry[0] == value:
                    cur_cate = re.findall(r'"(.*?)"', entry[1])[0]
                    cate_dict[cur_key] = cur_cate
    return cate_dict
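
To make the parsing above concrete, here is a small sketch of the label-map format load_cate_dict_from_pbtxt expects (the temporary path and the single entry are illustrative; the import assumes the tf/ directory or the shipped model.zip is on sys.path):

from model.preprocess import load_cate_dict_from_pbtxt

sample = '''item {
  name: "/m/01g317"
  id: 1
  display_name: "person"
}'''
with open("/tmp/sample_label_map.pbtxt", "w") as f:
    f.write(sample)

print(load_cate_dict_from_pbtxt("/tmp/sample_label_map.pbtxt"))  # {1: 'person'}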

@@ -0,0 +1,46 @@
import argparse
import os
import re
import zipfile

from pyspark import SparkConf, SparkContext, SQLContext


def init_spark(master, aut_jar):
    conf = SparkConf()
    conf.set("spark.jars", aut_jar)
    conf_path = os.path.dirname(os.path.abspath(__file__)) + "/spark.conf"
    conf_dict = read_conf(conf_path)
    for item, value in conf_dict.items():
        conf.set(item, value)
    sc = SparkContext(master, "aut image analysis", conf=conf)
    sql_context = SQLContext(sc)
    return conf, sc, sql_context


def get_args():
    parser = argparse.ArgumentParser(description='PySpark for Web Archive Image Retrieval.')
    parser.add_argument('--web_archive', help='Path to warcs.', default='/tuna1/scratch/nruest/geocites/warcs')
    parser.add_argument('--aut_jar', help='Path to compiled aut jar.', default='aut/target/aut-0.17.1-SNAPSHOT-fatjar.jar')
    parser.add_argument('--spark', help='Path to Apache Spark.', default='spark-2.3.2-bin-hadoop2.7/bin')
    parser.add_argument('--master', help='Apache Spark master IP address and port.', default='spark://127.0.1.1:7077')
    parser.add_argument('--img_model', help='Model for image processing.', default='ssd')
    parser.add_argument('--filter_size', nargs='+', type=int, help='Filter out images smaller than filter_size.', default=[640, 640])
    parser.add_argument('--output_path', help='Path to image model output.', default='warc_res')
    return parser.parse_args()


def zip_model_module(PYAUT_DIR):
    # Bundle the model package so sc.addPyFile can ship it to the executors.
    with zipfile.ZipFile(os.path.join(PYAUT_DIR, "tf", "model.zip"), "w") as zf:
        zf.write(os.path.join(PYAUT_DIR, "tf", "model", "__init__.py"), os.path.join("model", "__init__.py"))
        zf.write(os.path.join(PYAUT_DIR, "tf", "model", "object_detection.py"), os.path.join("model", "object_detection.py"))
        zf.write(os.path.join(PYAUT_DIR, "tf", "model", "preprocess.py"), os.path.join("model", "preprocess.py"))


def read_conf(conf_path):
    # Each non-empty line of spark.conf is "<key> <value>", separated by whitespace.
    conf_dict = {}
    with open(conf_path) as f:
        for line in f:
            conf = re.findall(r'\S+', line.strip())
            if len(conf) >= 2:
                conf_dict[conf[0]] = conf[1]
    return conf_dict

@@ -0,0 +1,7 @@
spark.sql.execution.arrow.enabled true
spark.sql.execution.arrow.maxRecordsPerBatch 320
spark.executor.memory 16G
spark.cores.max 48
spark.executor.cores 6
spark.driver.memory 64G
spark.task.cpus 6
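
These settings are picked up by read_conf/init_spark in util/init.py; in particular, spark.sql.execution.arrow.maxRecordsPerBatch caps how many image rows each call of the pandas UDF's batch_proc receives. A quick sketch of how the file is read (the relative path is an assumption; spark.conf just needs to sit next to init.py):

from util.init import read_conf

conf_dict = read_conf("util/spark.conf")  # assumed location, alongside init.py
print(conf_dict["spark.sql.execution.arrow.maxRecordsPerBatch"])  # '320'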
