Python formatting, and gitignore additions. (#326)

- Run black and isort on Python files.
- Move Spark config to example file.
- Update gitignore for 7a61f0e additions.
ruebot authored and ianmilligan1 committed Jul 18, 2019
1 parent f35d54e commit bd5ef14abd990c707a00b2f4df79756e73200718
@@ -13,3 +13,8 @@ workbench.xmi
 build
 derby.log
 metastore_db
+__pycache__/
+src/main/python/tf/model.zip
+src/main/python/tf/util/spark.conf
+src/main/python/tf/model/graph/
+src/main/python/tf/model/category/
@@ -1,5 +1,4 @@
 from aut.common import WebArchive
 from aut.udfs import extract_domain
 
-__all__ = ['WebArchive', 'extract_domain']
-
+__all__ = ["WebArchive", "extract_domain"]
@@ -1,5 +1,6 @@
 from pyspark.sql import DataFrame
 
+
 class WebArchive:
     def __init__(self, sc, sqlContext, path):
         self.sc = sc
@@ -12,4 +13,3 @@ def pages(self):
 
     def links(self):
         return DataFrame(self.loader.extractHyperlinks(self.path), self.sqlContext)
-
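For context, the WebArchive wrapper above is used roughly as follows (a minimal sketch; the SparkContext/SQLContext objects and the archive path are placeholders, mirroring the run script later in this commit):

    from aut.common import WebArchive

    # sc and sql_context come from an existing Spark session; the path is illustrative
    arc = WebArchive(sc, sql_context, "/path/to/warcs")
    links = arc.links()  # DataFrame built from loader.extractHyperlinks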
@@ -1,11 +1,13 @@
 from pyspark.sql.functions import udf
 from pyspark.sql.types import StringType
 
+
 def extract_domain_func(url):
-    url = url.replace('http://', '').replace('https://', '')
-    if '/' in url:
-        return url.split('/')[0].replace('www.', '')
+    url = url.replace("http://", "").replace("https://", "")
+    if "/" in url:
+        return url.split("/")[0].replace("www.", "")
     else:
-        return url.replace('www.', '')
+        return url.replace("www.", "")
 
+
 extract_domain = udf(extract_domain_func, StringType())
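A minimal usage sketch of the extract_domain UDF, which strips the scheme and a leading "www." and returns the host portion (the DataFrame df and its "url" column are assumptions for illustration):

    from pyspark.sql.functions import col
    from aut.udfs import extract_domain

    # df is any DataFrame with a string column named "url"
    domains = df.select(extract_domain(col("url")).alias("domain"))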
@@ -1,13 +1,15 @@
 import os
 import sys
-from util.init import *
+
+from pyspark.sql import DataFrame
+
 from model.object_detection import *
+from util.init import *
 
 PYAUT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
 sys.path.append(PYAUT_DIR)
 
 from aut.common import WebArchive
-from pyspark.sql import DataFrame
 
 if __name__ == "__main__":
     # initialization
@@ -23,11 +25,13 @@
     arc = WebArchive(sc, sql_context, args.web_archive)
     df = DataFrame(arc.loader.extractImages(arc.path), sql_context)
     filter_size = tuple(args.filter_size)
-    print("height >= %d and width >= %d"%filter_size)
-    preprocessed = df.filter("height >= %d and width >= %d"%filter_size)
+    print("height >= %d and width >= %d" % filter_size)
+    preprocessed = df.filter("height >= %d and width >= %d" % filter_size)
 
     # detection
     model_broadcast = detector.broadcast()
     detect_udf = detector.get_detect_udf(model_broadcast)
-    res = preprocessed.select("url", detect_udf(col("bytes")).alias("prediction"), "bytes")
+    res = preprocessed.select(
+        "url", detect_udf(col("bytes")).alias("prediction"), "bytes"
+    )
     res.write.json(args.output_path)
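For reference, each line written by res.write.json() is a JSON record with the three selected columns; the values below are made-up placeholders, but the shape follows the detection UDF's ArrayType(ArrayType(FloatType())) result (scores first, class ids second):

    {"url": "http://example.com/img.jpg", "prediction": [[0.93, 0.41], [1.0, 18.0]], "bytes": "<base64-encoded image>"}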
@@ -1,13 +1,19 @@
-import numpy as np
 import argparse
+
+import numpy as np
 
 from model.object_detection import SSDExtractor
 
 
 def get_args():
-    parser = argparse.ArgumentParser(description='Extracting images from model output.')
-    parser.add_argument('--res_dir', help='Path of result (model output) directory.')
-    parser.add_argument('--output_dir', help='Path of extracted image file output directory.')
-    parser.add_argument('--threshold', type=float, help='Threshold of detection confidence scores.')
+    parser = argparse.ArgumentParser(description="Extracting images from model output.")
+    parser.add_argument("--res_dir", help="Path of result (model output) directory.")
+    parser.add_argument(
+        "--output_dir", help="Path of extracted image file output directory."
+    )
+    parser.add_argument(
+        "--threshold", type=float, help="Threshold of detection confidence scores."
+    )
     return parser.parse_args()
 
+
@@ -17,23 +17,21 @@ def __init__(self, res_dir, output_dir):
         self.res_dir = res_dir
         self.output_dir = output_dir
 
-
     def _extract_and_save(self, rec, class_ids, threshold):
         raise NotImplementedError("Please overwrite this method.")
 
-
     def extract_and_save(self, class_ids, threshold):
         if class_ids == "all":
             class_ids = list(self.cate_dict.keys())
 
         for idx in class_ids:
             cls = self.cate_dict[idx]
-            check_dir(self.output_dir + "/%s/"%cls, create=True)
+            check_dir(self.output_dir + "/%s/" % cls, create=True)
 
         for fname in os.listdir(self.res_dir):
             if fname.startswith("part-"):
-                print("Extracting:", self.res_dir+"/"+fname)
-                with open(self.res_dir+"/"+fname) as f:
+                print("Extracting:", self.res_dir + "/" + fname)
+                with open(self.res_dir + "/" + fname) as f:
                     for line in f:
                         rec = json.loads(line)
                         self._extract_and_save(rec, class_ids, threshold)
@@ -43,47 +41,56 @@ class SSD:
     def __init__(self, sc, sql_context, args):
         self.sc = sc
         self.sql_context = sql_context
-        self.category = load_cate_dict_from_pbtxt("%s/category/mscoco_label_map.pbtxt"%PKG_DIR)
-        self.checkpoint = "%s/graph/ssd_mobilenet_v1_fpn_640x640/frozen_inference_graph.pb"%PKG_DIR
+        self.category = load_cate_dict_from_pbtxt(
+            "%s/category/mscoco_label_map.pbtxt" % PKG_DIR
+        )
+        self.checkpoint = (
+            "%s/graph/ssd_mobilenet_v1_fpn_640x640/frozen_inference_graph.pb" % PKG_DIR
+        )
         self.args = args
-        with tf.io.gfile.GFile(self.checkpoint, 'rb') as f:
+        with tf.io.gfile.GFile(self.checkpoint, "rb") as f:
             model_params = f.read()
         self.model_params = model_params
 
-
     def broadcast(self):
         return self.sc.broadcast(self.model_params)
 
-
    def get_detect_udf(self, model_broadcast):
        def batch_proc(bytes_batch):
            with tf.Graph().as_default() as g:
                graph_def = tf.GraphDef()
                graph_def.ParseFromString(model_broadcast.value)
-                tf.import_graph_def(graph_def, name='')
-                image_tensor = g.get_tensor_by_name('image_tensor:0')
-                detection_scores = g.get_tensor_by_name('detection_scores:0')
-                detection_classes = g.get_tensor_by_name('detection_classes:0')
+                tf.import_graph_def(graph_def, name="")
+                image_tensor = g.get_tensor_by_name("image_tensor:0")
+                detection_scores = g.get_tensor_by_name("detection_scores:0")
+                detection_classes = g.get_tensor_by_name("detection_classes:0")
 
             with tf.Session().as_default() as sess:
                 result = []
                 image_size = (640, 640)
                 images = np.array([img2np(b, image_size) for b in bytes_batch])
-                res = sess.run([detection_scores, detection_classes], feed_dict={image_tensor: images})
+                res = sess.run(
+                    [detection_scores, detection_classes],
+                    feed_dict={image_tensor: images},
+                )
                 for i in range(res[0].shape[0]):
                     result.append([res[0][i], res[1][i]])
                 return pd.Series(result)
-        return pandas_udf(ArrayType(ArrayType(FloatType())), PandasUDFType.SCALAR)(batch_proc)
+
+        return pandas_udf(ArrayType(ArrayType(FloatType())), PandasUDFType.SCALAR)(
+            batch_proc
+        )
 
 
 class SSDExtractor(ImageExtractor):
     def __init__(self, res_dir, output_dir):
         super().__init__(res_dir, output_dir)
-        self.cate_dict = load_cate_dict_from_pbtxt("%s/category/mscoco_label_map.pbtxt"%PKG_DIR)
-
+        self.cate_dict = load_cate_dict_from_pbtxt(
+            "%s/category/mscoco_label_map.pbtxt" % PKG_DIR
+        )
 
     def _extract_and_save(self, rec, class_ids, threshold):
-        pred = rec['prediction']
+        pred = rec["prediction"]
         scores = np.array(pred[0])
         classes = np.array(pred[1])
         valid_classes = np.unique(classes[scores >= threshold])
@@ -102,8 +109,7 @@ def _extract_and_save(self, rec, class_ids, threshold):
             cls = self.cate_dict[cls_idx]
             try:
                 img = str2img(rec["bytes"])
-                img.save(self.output_dir+ "/%s/"%cls + url_parse(rec["url"]))
+                img.save(self.output_dir + "/%s/" % cls + url_parse(rec["url"]))
             except:
-                fname = self.output_dir+ "/%s/"%cls + url_parse(rec["url"])
+                fname = self.output_dir + "/%s/" % cls + url_parse(rec["url"])
                 print("Failing to save:", fname)
-
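A hedged sketch of how SSDExtractor ties the two scripts together: res_dir points at the JSON output written by the detection script, while the output directory name and the 0.5 confidence threshold are only example values.

    from model.object_detection import SSDExtractor

    extractor = SSDExtractor(res_dir="warc_res", output_dir="extracted_images")
    extractor.extract_and_save(class_ids="all", threshold=0.5)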
@@ -7,7 +7,7 @@
 
 
 def str2img(byte_str):
-    return Image.open(io.BytesIO(base64.b64decode(bytes(byte_str, 'utf-8'))))
+    return Image.open(io.BytesIO(base64.b64decode(bytes(byte_str, "utf-8"))))
 
 
 def img2np(byte_str, resize=None):
@@ -22,7 +22,7 @@ def img2np(byte_str, resize=None):
     if len(img_shape) == 2:
         img = np.stack([img, img, img], axis=-1)
     elif img_shape[-1] >= 3:
-        img = img[:,:,:3]
+        img = img[:, :, :3]
 
     return img
 
@@ -58,4 +58,3 @@ def load_cate_dict_from_pbtxt(path, key="id", value="display_name"):
         cur_cate = re.findall(r'"(.*?)"', entry[1])[0]
         cate_dict[cur_key] = cur_cate
     return cate_dict
-
@@ -1,14 +1,15 @@
 import argparse
+import os
+import re
 import zipfile
 
 from pyspark import SparkConf, SparkContext, SQLContext
-import re
-import os
 
+
 def init_spark(master, aut_jar):
     conf = SparkConf()
     conf.set("spark.jars", aut_jar)
-    conf_path = os.path.dirname(os.path.abspath(__file__))+"/spark.conf"
+    conf_path = os.path.dirname(os.path.abspath(__file__)) + "/spark.conf"
     conf_dict = read_conf(conf_path)
     for item, value in conf_dict.items():
         conf.set(item, value)
@@ -18,29 +19,63 @@ def init_spark(master, aut_jar):
 
 
 def get_args():
-    parser = argparse.ArgumentParser(description='PySpark for Web Archive Image Retrieval.')
-    parser.add_argument('--web_archive', help='Path to warcs.', default='/tuna1/scratch/nruest/geocites/warcs')
-    parser.add_argument('--aut_jar', help='Path to compiled aut jar.', default='aut/target/aut-0.17.1-SNAPSHOT-fatjar.jar')
-    parser.add_argument('--spark', help='Path to Apache Spark.', default='spark-2.3.2-bin-hadoop2.7/bin')
-    parser.add_argument('--master', help='Apache Spark master IP address and port.', default='spark://127.0.1.1:7077')
-    parser.add_argument('--img_model', help='Model for image processing.', default='ssd')
-    parser.add_argument('--filter_size', nargs='+', type=int, help='Filter out images smaller than filter_size', default=[640, 640])
-    parser.add_argument('--output_path', help='Path to image model output.', default='warc_res')
+    parser = argparse.ArgumentParser(
+        description="PySpark for Web Archive Image Retrieval."
+    )
+    parser.add_argument(
+        "--web_archive",
+        help="Path to warcs.",
+        default="/tuna1/scratch/nruest/geocites/warcs",
+    )
+    parser.add_argument(
+        "--aut_jar",
+        help="Path to compiled aut jar.",
+        default="aut/target/aut-0.17.1-SNAPSHOT-fatjar.jar",
+    )
+    parser.add_argument(
+        "--spark", help="Path to Apache Spark.", default="spark-2.3.2-bin-hadoop2.7/bin"
+    )
+    parser.add_argument(
+        "--master",
+        help="Apache Spark master IP address and port.",
+        default="spark://127.0.1.1:7077",
+    )
+    parser.add_argument(
+        "--img_model", help="Model for image processing.", default="ssd"
+    )
+    parser.add_argument(
+        "--filter_size",
+        nargs="+",
+        type=int,
+        help="Filter out images smaller than filter_size",
+        default=[640, 640],
+    )
+    parser.add_argument(
+        "--output_path", help="Path to image model output.", default="warc_res"
+    )
     return parser.parse_args()
 
 
 def zip_model_module(PYAUT_DIR):
     zip = zipfile.ZipFile(os.path.join(PYAUT_DIR, "tf", "model.zip"), "w")
-    zip.write(os.path.join(PYAUT_DIR, "tf", "model", "__init__.py"), os.path.join("model", "__init__.py"))
-    zip.write(os.path.join(PYAUT_DIR, "tf", "model", "object_detection.py"), os.path.join("model", "object_detection.py"))
-    zip.write(os.path.join(PYAUT_DIR, "tf", "model", "preprocess.py"), os.path.join("model", "preprocess.py"))
+    zip.write(
+        os.path.join(PYAUT_DIR, "tf", "model", "__init__.py"),
+        os.path.join("model", "__init__.py"),
+    )
+    zip.write(
+        os.path.join(PYAUT_DIR, "tf", "model", "object_detection.py"),
+        os.path.join("model", "object_detection.py"),
+    )
+    zip.write(
+        os.path.join(PYAUT_DIR, "tf", "model", "preprocess.py"),
+        os.path.join("model", "preprocess.py"),
+    )
 
 
 def read_conf(conf_path):
     conf_dict = {}
     with open(conf_path) as f:
         for line in f:
-            conf = re.findall(r'\S+', line.strip())
+            conf = re.findall(r"\S+", line.strip())
             conf_dict[conf[0]] = conf[1]
     return conf_dict
-
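Since the commit moves the Spark configuration into an example file, note the format read_conf expects: one whitespace-separated property/value pair per line. A hypothetical spark.conf (the property names are standard Spark settings; the values are invented):

    spark.executor.memory 16g
    spark.driver.memory 8g

With that input, read_conf would return {"spark.executor.memory": "16g", "spark.driver.memory": "8g"}, and init_spark applies each pair with conf.set(item, value).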
File renamed without changes.
