# Source code for the quickdraw package (QuickDrawData, QuickDrawDataGroup, QuickDrawing)

from __future__ import unicode_literals

import struct
from random import choice
from os import path, makedirs
from requests import get
from requests.exceptions import ConnectionError
from PIL import Image, ImageDraw


# Default directory, relative to the current working directory, in which
# downloaded drawing files are cached.
CACHE_DIR = path.join(path.curdir, ".quickdrawcache")

class QuickDrawData():
    """
    Allows interaction with the Google Quick, Draw! data set, downloads
    Quick Draw data and loads it into memory for easy access and
    processing.

    The following example will load the anvil drawings and get a single
    drawing::

        from quickdraw import QuickDrawData

        qd = QuickDrawData()
        anvil = qd.get_drawing("anvil")
        anvil.image.save("my_anvil.gif")

    :param bool recognized:
        If ``True`` only recognized drawings will be loaded, if ``False``
        only unrecognized drawings will be loaded, if ``None`` (the default)
        both recognized and unrecognized drawings will be loaded.

    :param int max_drawings:
        The maximum number of drawings to be loaded into memory,
        defaults to 1000.

    :param bool refresh_data:
        If ``True`` forces data to be downloaded even if it has been
        downloaded before, defaults to ``False``.

    :param bool jit_loading:
        If ``True`` (the default) only downloads and loads data into
        memory when it is required (jit = just in time). If ``False``
        all drawings will be downloaded and loaded into memory.

    :param bool print_messages:
        If ``True`` (the default), status messages will be printed
        stating when data is being downloaded or loaded.

    :param string cache_dir:
        Specify a cache directory to use when downloading data files,
        defaults to ``./.quickdrawcache``.
    """

    def __init__(
            self,
            recognized=None,
            max_drawings=1000,
            refresh_data=False,
            jit_loading=True,
            print_messages=True,
            cache_dir=CACHE_DIR):

        self._recognized = recognized
        self._print_messages = print_messages
        self._refresh_data = refresh_data
        self._max_drawings = max_drawings
        self._cache_dir = cache_dir

        # drawing groups already loaded into memory, keyed by drawing name
        self._drawing_groups = {}

        # if not jit (just in time) loading, load all drawings up front
        if not jit_loading:
            self.load_all_drawings()

    def get_drawing(self, name, index=None):
        """
        Get a drawing.

        Returns an instance of :class:`QuickDrawing` representing a single
        Quick, Draw drawing.

        :param string name:
            The name of the drawing to get (anvil, ant, aircraft, etc).

        :param int index:
            The index of the drawing to get.

            If ``None`` (the default) a random drawing will be returned.
        """
        return self.get_drawing_group(name).get_drawing(index)

    def get_drawing_group(self, name):
        """
        Get a group of drawings by name.

        The group is created (downloading and loading its data) the first
        time it is requested and reused on every later request.

        Returns an instance of :class:`QuickDrawDataGroup`.

        :param string name:
            The name of the drawings (anvil, ant, aircraft, etc).
        """
        # has this drawing group been loaded to memory?
        if name not in self._drawing_groups:
            self._drawing_groups[name] = QuickDrawDataGroup(
                name,
                recognized=self._recognized,
                max_drawings=self._max_drawings,
                refresh_data=self._refresh_data,
                print_messages=self._print_messages,
                cache_dir=self._cache_dir)

        return self._drawing_groups[name]

    def search_drawings(self, name, key_id=None, recognized=None,
                        countrycode=None, timestamp=None):
        """
        Search the drawings.

        Returns a list of :class:`QuickDrawing` instances representing the
        matched drawings.

        Note - search criteria are compound: a drawing must satisfy every
        criterion that is not ``None``.

        Search for all the anvil drawings with the ``countrycode`` "PL" ::

            from quickdraw import QuickDrawData

            qd = QuickDrawData()
            results = qd.search_drawings("anvil", countrycode="PL")

        :param string name:
            The name of the drawings (anvil, ant, aircraft, etc) to search.

        :param int key_id:
            The ``key_id`` to search for. If ``None`` (the default) the
            ``key_id`` is not used.

        :param bool recognized:
            To search for drawings which were ``recognized``. If ``None``
            (the default) ``recognized`` is not used.

        :param str countrycode:
            To search for drawings with the ``countrycode``. If ``None``
            (the default) ``countrycode`` is not used.

        :param int timestamp:
            To search for drawings with the ``timestamp``. If ``None``
            (the default) ``timestamp`` is not used.
        """
        return self.get_drawing_group(name).search_drawings(
            key_id, recognized, countrycode, timestamp)

    def load_all_drawings(self):
        """
        Loads (and downloads if required) all drawings into memory.
        """
        self.load_drawings(self.drawing_names)

    def load_drawings(self, list_of_drawings):
        """
        Loads (and downloads if required) the named drawings into memory.

        :param list list_of_drawings:
            A list of the drawings to be loaded (anvil, ant, aircraft, etc).
        """
        for drawing_group in list_of_drawings:
            self.get_drawing_group(drawing_group)

    @property
    def drawing_names(self):
        """
        Returns a list of all the potential drawing names.
        """
        return QUICK_DRAWING_NAMES

    @property
    def loaded_drawings(self):
        """
        Returns a list of the names of the drawings which have been loaded
        into memory.
        """
        return list(self._drawing_groups)
class QuickDrawDataGroup():
    """
    Allows interaction with a group of Quick, Draw! drawings.

    The following example will load the ant group of drawings and get a
    single drawing::

        from quickdraw import QuickDrawDataGroup

        ants = QuickDrawDataGroup("ant")
        ant = ants.get_drawing()
        ant.image.save("my_ant.gif")

    :param string name:
        The name of the drawings to be loaded (anvil, ant, aircraft, etc).

    :param bool recognized:
        If ``True`` only recognized drawings will be loaded, if ``False``
        only unrecognized drawings will be loaded, if ``None`` (the default)
        both recognized and unrecognized drawings will be loaded.

    :param int max_drawings:
        The maximum number of drawings to be loaded into memory,
        defaults to 1000. ``None`` loads every drawing in the file.

    :param bool refresh_data:
        If ``True`` forces data to be downloaded even if it has been
        downloaded before, defaults to ``False``.

    :param bool print_messages:
        If ``True`` (the default), status messages will be printed
        stating when data is being downloaded or loaded.

    :param string cache_dir:
        Specify a cache directory to use when downloading data files,
        defaults to ``./.quickdrawcache``.

    :raises ValueError:
        If ``name`` is not a known drawing name.
    """

    def __init__(
            self,
            name,
            recognized=None,
            max_drawings=1000,
            refresh_data=False,
            print_messages=True,
            cache_dir=CACHE_DIR):

        if name not in QUICK_DRAWING_NAMES:
            raise ValueError("{} is not a valid google quick drawing".format(name))

        self._name = name
        self._print_messages = print_messages
        self._max_drawings = max_drawings
        self._cache_dir = cache_dir
        self._recognized = recognized

        self._drawings = []
        self._drawing_count = 0

        # the cached binary file for this drawing
        filename = path.join(self._cache_dir, QUICK_DRAWING_FILES[name])

        # if the binary file doesn't exist or refresh_data is True,
        # download the file
        if not path.isfile(filename) or refresh_data:

            # if the cache dir doesn't exist, create it
            if not path.isdir(self._cache_dir):
                makedirs(self._cache_dir)

            url = BINARY_URL + QUICK_DRAWING_FILES[name]
            self._download_drawings_binary(url, filename)

        self._load_drawings(filename)

    def _download_drawings_binary(self, url, filename):
        """
        Download the binary drawing file at ``url`` and save it as
        ``filename``.

        :raises Exception:
            If the connection fails, the server answers with an error
            status, or the file is missing after the download.
        """
        try:
            r = get(url, stream=True)

            # refuse to cache an HTTP error page as if it were drawing data
            if not r.ok:
                raise Exception(
                    "download of {} failed - server returned status {}".format(
                        self._name, r.status_code))

            self._print_message("downloading {} from {}".format(self._name, url))
            with open(filename, 'wb') as f:
                for chunk in r.iter_content(chunk_size=1024):
                    if chunk:
                        f.write(chunk)
        except ConnectionError:
            raise Exception("connection error - you need to be connected to the internet to download {} drawings".format(self._name))

        # check file exists
        if not path.isfile(filename):
            raise Exception("something went wrong with the download of {} - file not found!".format(self._name))

        self._print_message("download complete")

    @staticmethod
    def _read_drawing(binary_file):
        """
        Read a single drawing record from ``binary_file`` and return it as
        a dictionary.

        :raises struct.error:
            If there is not a complete record left to read.
        """
        # Fixed-size header. "<" selects little-endian with no padding, so
        # the record is parsed identically on every platform:
        # key_id (8 bytes), countrycode (2), recognized (1), timestamp (4),
        # number of strokes (2).
        key_id, countrycode, recognized, timestamp, n_strokes = struct.unpack(
            '<Q2sbIH', binary_file.read(17))

        image = []
        for _ in range(n_strokes):
            # each stroke is a point count followed by all of the
            # x co-ordinates and then all of the y co-ordinates
            n_points, = struct.unpack('<H', binary_file.read(2))
            fmt = '<{}B'.format(n_points)
            x = struct.unpack(fmt, binary_file.read(n_points))
            y = struct.unpack(fmt, binary_file.read(n_points))
            image.append((x, y))

        return {
            'key_id': key_id,
            'countrycode': countrycode,
            'recognized': recognized,
            'timestamp': timestamp,
            'n_strokes': n_strokes,
            'image': image
        }

    def _load_drawings(self, filename):
        """
        Load drawings from the binary file ``filename`` into memory.

        Reading stops when ``max_drawings`` matching drawings have been
        loaded (no limit when it is ``None``) or the file is exhausted.
        Drawings which do not match the ``recognized`` filter are skipped
        and do not count towards the limit.
        """
        self._print_message("loading {} drawings".format(self._name))

        self._drawings = []
        self._drawing_count = 0

        # the context manager guarantees the file is closed, even if
        # reading fails part way through
        with open(filename, 'rb') as binary_file:
            while (self._max_drawings is None
                   or self._drawing_count < self._max_drawings):
                try:
                    drawing = self._read_drawing(binary_file)
                except struct.error:
                    # nothing left to read
                    break

                if (self._recognized is not None
                        and bool(drawing['recognized']) != self._recognized):
                    continue

                self._drawings.append(drawing)
                self._drawing_count += 1

        self._print_message("load complete")

    def _print_message(self, message):
        """
        Print ``message`` if status messages are enabled.
        """
        if self._print_messages:
            print(message)

    @property
    def drawing_count(self):
        """
        Returns the number of drawings loaded.
        """
        return self._drawing_count

    @property
    def drawings(self):
        """
        An iterator of all the drawings loaded in this group.

        Yields :class:`QuickDrawing` objects.

        Load the anvil group of drawings and iterate through them::

            from quickdraw import QuickDrawDataGroup

            anvils = QuickDrawDataGroup("anvil")
            for anvil in anvils.drawings:
                print(anvil)

        """
        # The position is kept in a local variable so that every iteration
        # starts at the first drawing and concurrent or nested iterations
        # do not interfere with each other.
        for index in range(self._drawing_count):
            yield self.get_drawing(index)

    def get_drawing(self, index=None):
        """
        Get a drawing from this group.

        Returns an instance of :class:`QuickDrawing` representing a single
        Quick, Draw drawing.

        Get a single anvil drawing::

            from quickdraw import QuickDrawDataGroup

            anvils = QuickDrawDataGroup("anvil")
            anvil = anvils.get_drawing()

        :param int index:
            The index of the drawing to get.

            If ``None`` (the default) a random drawing will be returned.

        :raises IndexError:
            If ``index`` is out of range or no drawings are loaded.
        """
        if index is None:
            if not self._drawings:
                raise IndexError(
                    "there are no {} drawings loaded".format(self._name))
            return QuickDrawing(self._name, choice(self._drawings))

        if index < self.drawing_count:
            return QuickDrawing(self._name, self._drawings[index])

        raise IndexError("index {} out of range, there are {} drawings".format(index, self.drawing_count))

    def search_drawings(self, key_id=None, recognized=None,
                        countrycode=None, timestamp=None):
        """
        Searches the drawings in this group.

        Returns a list of :class:`QuickDrawing` instances representing the
        matched drawings.

        Note - search criteria are compound: a drawing must satisfy every
        criterion that is not ``None``.

        Search for all the drawings with the ``countrycode`` "PL" ::

            from quickdraw import QuickDrawDataGroup

            anvils = QuickDrawDataGroup("anvil")
            results = anvils.search_drawings(countrycode="PL")

        :param int key_id:
            The ``key_id`` to search for. If ``None`` (the default) the
            ``key_id`` is not used.

        :param bool recognized:
            To search for drawings which were ``recognized``. If ``None``
            (the default) ``recognized`` is not used.

        :param str countrycode:
            To search for drawings with the ``countrycode``. If ``None``
            (the default) ``countrycode`` is not used.

        :param int timestamp:
            To search for drawings with the ``timestamp``. If ``None``
            (the default) ``timestamp`` is not used.
        """
        results = []

        for drawing in self.drawings:
            if key_id is not None and key_id != drawing.key_id:
                continue
            if recognized is not None and recognized != drawing.recognized:
                continue
            if countrycode is not None and countrycode != drawing.countrycode:
                continue
            if timestamp is not None and timestamp != drawing.timestamp:
                continue
            results.append(drawing)

        return results
class QuickDrawing():
    """
    Represents a single Quick, Draw! drawing.

    :param string name:
        The name of the drawing (anvil, aircraft, ant, etc).

    :param dict drawing_data:
        The raw drawing record with the keys ``key_id``, ``countrycode``,
        ``recognized``, ``timestamp``, ``n_strokes`` and ``image``.
    """

    def __init__(self, name, drawing_data):
        self._name = name
        self._drawing_data = drawing_data

        # built lazily on first access
        self._strokes = None
        self._image = None

    @property
    def name(self):
        """
        Returns the name of the drawing (anvil, aircraft, ant, etc).
        """
        return self._name

    @property
    def key_id(self):
        """
        Returns the id of the drawing.
        """
        return self._drawing_data["key_id"]

    @property
    def countrycode(self):
        """
        Returns the country code for the drawing.
        """
        return self._drawing_data["countrycode"].decode("utf-8")

    @property
    def recognized(self):
        """
        Returns a boolean representing whether the drawing was recognized.
        """
        return bool(self._drawing_data["recognized"])

    @property
    def timestamp(self):
        """
        Returns the time the drawing was created (in seconds since the
        epoch).
        """
        return self._drawing_data["timestamp"]

    @property
    def no_of_strokes(self):
        """
        Returns the number of pen strokes used to create the drawing.
        """
        return self._drawing_data["n_strokes"]

    @property
    def image_data(self):
        """
        Returns the raw image data as list of strokes with a list of X
        co-ordinates and a list of Y co-ordinates.

        Co-ordinates are aligned to the top-left hand corner with values
        from 0 to 255.

        See the Quick, Draw! dataset documentation for more information
        regarding how the data is represented.
        """
        return self._drawing_data["image"]

    @property
    def strokes(self):
        """
        Returns a list of pen strokes containing a list of (x,y)
        coordinates which make up the drawing.

        To iterate though the strokes data use::

            from quickdraw import QuickDrawData

            qd = QuickDrawData()
            anvil = qd.get_drawing("anvil")
            for stroke in anvil.strokes:
                for x, y in stroke:
                    print("x={} y={}".format(x, y))

        :raises Exception:
            If a stroke has a different number of x and y co-ordinates.
        """
        if self._strokes is None:
            # Build into a local list and only cache it once complete, so
            # a malformed stroke cannot leave a partial result cached.
            strokes = []
            for stroke in self.image_data:
                xs = stroke[0]
                ys = stroke[1]

                if len(xs) != len(ys):
                    raise Exception("something is wrong, different number of x's and y's")

                strokes.append(list(zip(xs, ys)))

            self._strokes = strokes

        return self._strokes

    @property
    def image(self):
        """
        Returns a PIL ``Image`` object of the drawing on a white background
        with a black drawing. Alternative image parameters can be set using
        ``get_image()``.

        The image is created on first access and then reused.

        To save the image you would use the ``save`` method::

            from quickdraw import QuickDrawData

            qd = QuickDrawData()
            anvil = qd.get_drawing("anvil")
            anvil.image.save("my_anvil.gif")

        """
        if self._image is None:
            self._image = self.get_image()

        return self._image

    def get_image(self, stroke_color=(0, 0, 0), stroke_width=2,
                  bg_color=(255, 255, 255)):
        """
        Get a PIL ``Image`` object of the drawing.

        :param list stroke_color:
            A list of RGB (red, green, blue) values for the stroke color,
            defaults to (0,0,0).

        :param int stroke_width:
            The width of the stroke, defaults to 2.

        :param list bg_color:
            A list of RGB (red, green, blue) values for the background
            color, defaults to (255,255,255).
        """
        # NOTE(review): the canvas is 255x255 while co-ordinates are
        # documented as 0 to 255, so points at 255 fall just outside it -
        # confirm whether 256 was intended before changing the size.
        image = Image.new("RGB", (255, 255), color=bg_color)
        image_draw = ImageDraw.Draw(image)

        for stroke in self.strokes:
            image_draw.line(stroke, fill=stroke_color, width=stroke_width)

        return image

    def __str__(self):
        return "QuickDrawing key_id={}".format(self.key_id)