[Date Prev][Date Next] [Thread Prev][Thread Next] [Date Index] [Thread Index]

Bug#762934: Models: Refactor models to contain only ORM abstraction



Hello,

I didn't know someone else was working on this. I am sorry if I am
duplicating work.
Here is my patch.

Cheers,

Orestis
From 659479c85c91a32e6c645549766b98e4bdcc9717 Mon Sep 17 00:00:00 2001
From: Orestis Ioannou <orestis@oioannou.com>
Date: Thu, 5 Mar 2015 09:03:09 +0100
Subject: [PATCH] Models: Refactor models to contain only ORM abstraction

Closes: #762934
Moved the static queries of each model and non-ORM objects to query.py.
Created a test suite to verify the Queries.
---
 debsources/app/views.py          |  18 +-
 debsources/models.py             | 379 +-------------------------------------
 debsources/query.py              | 387 +++++++++++++++++++++++++++++++++++++++
 debsources/tests/test_queries.py |  69 +++++++
 4 files changed, 467 insertions(+), 386 deletions(-)
 create mode 100644 debsources/query.py
 create mode 100644 debsources/tests/test_queries.py

diff --git a/debsources/app/views.py b/debsources/app/views.py
index 129d4ad..e32c536 100644
--- a/debsources/app/views.py
+++ b/debsources/app/views.py
@@ -36,8 +36,10 @@ from debsources.excepts import (
     InvalidPackageOrVersionError, FileOrFolderNotFound,
     Http500Error, Http404Error, Http404ErrorSuggestions, Http403Error)
 from debsources.models import (
-    Ctag, Package, PackageName, Checksum, Location, Directory,
-    SourceFile, File, Suite)
+    Ctag, Package, PackageName, Checksum, File, Suite)
+from debsources.query import (Location, Directory,
+    SourceFile, Queries)
+
 from debsources.app.sourcecode import SourceCodeIterator
 from debsources.app.forms import SearchForm
 from debsources.app.infobox import Infobox
@@ -62,7 +64,7 @@ def skeleton_variables():
     update_ts_file = os.path.join(app.config['CACHE_DIR'], 'last-update')
     last_update = local_info.read_update_ts(update_ts_file)
 
-    packages_prefixes = PackageName.get_packages_prefixes(
+    packages_prefixes = Queries.get_packages_prefixes(
         app.config["CACHE_DIR"])
 
     credits_file = os.path.join(app.config["LOCAL_DIR"], "credits.html")
@@ -144,7 +146,7 @@ def deal_404_error(error, mode='html'):
         if isinstance(error, Http404ErrorSuggestions):
             # let's suggest all the possible locations with a different
             # package version
-            possible_versions = PackageName.list_versions(
+            possible_versions = Queries.list_versions(
                 session, error.package)
             suggestions = ['/'.join(filter(None,
                                     [error.package, v.version, error.path]))
@@ -441,7 +443,7 @@ class PrefixView(GeneralView):
         suite = suite.lower()
         if suite == "all":
             suite = ""
-        if prefix in PackageName.get_packages_prefixes(
+        if prefix in Queries.get_packages_prefixes(
                 app.config["CACHE_DIR"]):
             try:
                 if not suite:
@@ -501,7 +503,7 @@ class SourceView(GeneralView):
             suite = ""
         # we list the version with suites it belongs to
         try:
-            versions_w_suites = PackageName.list_versions_w_suites(
+            versions_w_suites = Queries.list_versions_w_suites(
                 session, packagename, suite)
         except InvalidPackageOrVersionError:
             raise Http404Error("%s not found" % packagename)
@@ -605,7 +607,7 @@ class SourceView(GeneralView):
         when 'latest' is provided instead of a version number
         """
         try:
-            versions = PackageName.list_versions(session, package)
+            versions = Queries.list_versions(session, package)
         except InvalidPackageOrVersionError:
             raise Http404Error("%s not found" % package)
         # the latest version is the latest item in the
@@ -876,7 +878,7 @@ class CtagView(GeneralView):
             pagination = None
             slice_ = None
 
-        (count, results) = Ctag.find_ctag(session, ctag, slice_=slice_,
+        (count, results) = Queries.find_ctag(session, ctag, slice_=slice_,
                                           package=package)
         if not self.all_:
             pagination = Pagination(page, offset, count)
diff --git a/debsources/models.py b/debsources/models.py
index e4d05fc..5f57246 100644
--- a/debsources/models.py
+++ b/debsources/models.py
@@ -16,30 +16,17 @@
 # You should have received a copy of the GNU Affero General Public License
 # along with this program.  If not, see <https://www.gnu.org/licenses/>.
 
-import os
-import magic
-import stat
-from collections import namedtuple
-
 from sqlalchemy import Column, ForeignKey
 from sqlalchemy import UniqueConstraint, PrimaryKeyConstraint
 from sqlalchemy import Index
 from sqlalchemy import Boolean, Date, DateTime, Integer, LargeBinary, String
 from sqlalchemy import Enum
-from sqlalchemy import and_
-from sqlalchemy import func as sql_func
 from sqlalchemy.orm import relationship
 from sqlalchemy.ext.declarative import declarative_base
 
-from debian.debian_support import version_compare
 
-from debsources.excepts import InvalidPackageOrVersionError, \
-    FileOrFolderNotFound
 from debsources.consts import VCS_TYPES, SLOCCOUNT_LANGUAGES, \
-    CTAGS_LANGUAGES, METRIC_TYPES, AREAS, PREFIXES_DEFAULT
-from debsources import filetype
-from debsources.debmirror import SourcePackage
-from debsources.consts import SUITES
+    CTAGS_LANGUAGES, METRIC_TYPES
 
 Base = declarative_base()
 
@@ -64,77 +51,6 @@ class PackageName(Base):
     def __repr__(self):
         return self.name
 
-    @staticmethod
-    def get_packages_prefixes(cache_dir):
-        """
-        returns the packages prefixes (a, b, ..., liba, libb, ..., y, z)
-        cache_dir: the cache directory, usually comes from the app config
-        """
-        try:
-            with open(os.path.join(cache_dir, 'pkg-prefixes')) as f:
-                prefixes = [l.rstrip() for l in f]
-        except IOError:
-            prefixes = PREFIXES_DEFAULT
-        return prefixes
-
-    @staticmethod
-    def list_versions(session, packagename, suite=""):
-        """
-        return all versions of a packagename. if suite is specified, only
-        versions contained in that suite are returned.
-        """
-        try:
-            name_id = session.query(PackageName) \
-                             .filter(PackageName.name == packagename) \
-                             .first().id
-        except Exception:
-            raise InvalidPackageOrVersionError(packagename)
-        try:
-            if not suite:
-                versions = session.query(Package) \
-                                  .filter(Package.name_id == name_id).all()
-            else:
-                versions = (session.query(Package)
-                                   .filter(Package.name_id == name_id)
-                                   .filter(sql_func.lower(Suite.suite)
-                                           == suite)
-                                   .filter(Suite.package_id == Package.id)
-                                   .all())
-        except Exception:
-            raise InvalidPackageOrVersionError(packagename)
-        # we sort the versions according to debian versions rules
-        versions = sorted(versions, cmp=version_compare)
-        return versions
-
-    @staticmethod
-    def list_versions_w_suites(session, packagename, suite=""):
-        """
-        return versions with suites. if suite is provided, then only return
-        versions contained in that suite.
-        """
-        # FIXME a left outer join on (Package, Suite) is more preferred.
-        # However, per https://stackoverflow.com/a/997467, custom aggregation
-        # function to concatenate the suite names for the group_by should be
-        # defined on database connection level.
-        versions = PackageName.list_versions(session, packagename, suite)
-        versions_w_suites = []
-        try:
-            for v in versions:
-                suites = session.query(Suite) \
-                                .filter(Suite.package_id == v.id) \
-                                .all()
-                # sort the suites according to debsources.consts.SUITES
-                # use keyfunc to make it py3 compatible
-                suites.sort(key=lambda s: SUITES['all'].index(s.suite))
-                suites = [s.suite for s in suites]
-                v = v.to_dict()
-                v['suites'] = suites
-                versions_w_suites.append(v)
-        except Exception:
-            raise InvalidPackageOrVersionError(packagename)
-
-        return versions_w_suites
-
     def to_dict(self):
         """
         simply serializes a package (because SQLAlchemy query results
@@ -351,41 +267,6 @@ class Ctag(Base):
     #                .filter(Ctag.tag in ctags)
     #                .filter(Ctag
 
-    @staticmethod
-    def find_ctag(session, ctag, package=None, slice_=None):
-        """
-        Returns places in the code where a ctag is found.
-             tuple (count, [sliced] results)
-
-        session: an SQLAlchemy session
-        ctag: the ctag to search
-        package: limit results to package
-        """
-
-        results = (session.query(PackageName.name.label("package"),
-                                 Package.version.label("version"),
-                                 Ctag.file_id.label("file_id"),
-                                 File.path.label("path"),
-                                 Ctag.line.label("line"))
-                   .filter(Ctag.tag == ctag)
-                   .filter(Ctag.package_id == Package.id)
-                   .filter(Ctag.file_id == File.id)
-                   .filter(Package.name_id == PackageName.id)
-                   )
-        if package is not None:
-            results = results.filter(PackageName.name == package)
-
-        results = results.order_by(Ctag.package_id, File.path)
-        count = results.count()
-        if slice_ is not None:
-            results = results.slice(slice_[0], slice_[1])
-        results = [dict(package=res.package,
-                        version=res.version,
-                        path=res.path,
-                        line=res.line)
-                   for res in results.all()]
-        return (count, results)
-
 
 class Metric(Base):
     __tablename__ = 'metrics'
@@ -477,261 +358,3 @@ class HistorySlocCount(Base):
     def __init__(self, suite, timestamp):
         self.suite = suite
         self.timestamp = timestamp
-
-# it's used in Location.get_stat
-# to bypass flake8 complaints, we do not inject the global namespace
-# with globals()["LongFMT"] = namedtuple...
-LongFMT = namedtuple("LongFMT", ["type", "perms", "size", "symlink_dest"])
-
-
-class Location(object):
-    """ a location in a package, can be a directory or a file """
-
-    def _get_debian_path(self, session, package, version, sources_dir):
-        """
-        Returns the Debian path of a package version.
-        For example: main/h
-                     contrib/libz
-        It's the path of a *version*, since a package can have multiple
-        versions in multiple areas (ie main/contrib/nonfree).
-
-        sources_dir: the sources directory, usually comes from the app config
-        """
-        prefix = SourcePackage.pkg_prefix(package)
-
-        try:
-            p_id = session.query(PackageName) \
-                          .filter(PackageName.name == package).first().id
-            varea = session.query(Package) \
-                           .filter(and_(Package.name_id == p_id,
-                                        Package.version == version)) \
-                           .first().area
-        except:
-            # the package or version doesn't exist in the database
-            # BUT: packages are stored for a longer time in the filesystem
-            # to allow codesearch.d.n and others less up-to-date platforms
-            # to point here.
-            # Problem: we don't know the area of such a package
-            # so we try in main, contrib and non-free.
-            for area in AREAS:
-                if os.path.exists(os.path.join(sources_dir, area,
-                                               prefix, package, version)):
-                    return os.path.join(area, prefix)
-
-            raise InvalidPackageOrVersionError("%s %s" % (package, version))
-
-        return os.path.join(varea, prefix)
-
-    def __init__(self, session, sources_dir, sources_static,
-                 package, version="", path=""):
-        """ initialises useful attributes """
-        debian_path = self._get_debian_path(session,
-                                            package, version, sources_dir)
-        self.package = package
-        self.version = version
-        self.path = path
-        self.path_to = os.path.join(package, version, path)
-
-        self.sources_path = os.path.join(
-            sources_dir,
-            debian_path,
-            self.path_to)
-
-        self.version_path = os.path.join(
-            sources_dir,
-            debian_path,
-            package,
-            version)
-
-        if not(os.path.exists(self.sources_path)):
-            raise FileOrFolderNotFound("%s" % (self.path_to))
-
-        self.sources_path_static = os.path.join(
-            sources_static,
-            debian_path,
-            self.path_to)
-
-    def is_dir(self):
-        """ True if self is a directory, False if it's not """
-        return os.path.isdir(self.sources_path)
-
-    def is_file(self):
-        """ True if sels is a file, False if it's not """
-        return os.path.isfile(self.sources_path)
-
-    def is_symlink(self):
-        """ True if a folder/file is a symbolic link file, False if it's not
-        """
-        return os.path.islink(self.sources_path)
-
-    def get_package(self):
-        return self.package
-
-    def get_version(self):
-        return self.version
-
-    def get_path(self):
-        return self.path
-
-    def get_deepest_element(self):
-        if self.version == "":
-            return self.package
-        elif self.path == "":
-            return self.version
-        else:
-            return self.path.split("/")[-1]
-
-    def get_path_to(self):
-        return self.path_to.rstrip("/")
-
-    @staticmethod
-    def get_stat(sources_path):
-        """
-        Returns the filetype and permissions of the folder/file
-        on the disk, unix-styled.
-        """
-        # When porting to Python3, use stat.filemode directly
-        sources_stat = os.lstat(sources_path)
-        sources_mode, sources_size = sources_stat.st_mode, sources_stat.st_size
-        perm_flags = [
-            (stat.S_IRUSR, "r", "-"),
-            (stat.S_IWUSR, "w", "-"),
-            (stat.S_IXUSR, "x", "-"),
-            (stat.S_IRGRP, "r", "-"),
-            (stat.S_IWGRP, "w", "-"),
-            (stat.S_IXGRP, "x", "-"),
-            (stat.S_IROTH, "r", "-"),
-            (stat.S_IWOTH, "w", "-"),
-            (stat.S_IXOTH, "x", "-"),
-            ]
-        # XXX these flags should be enough.
-        type_flags = [
-            (stat.S_ISLNK, "l"),
-            (stat.S_ISREG, "-"),
-            (stat.S_ISDIR, "d"),
-            ]
-        # add the file type: d/l/-
-        file_type = " "
-        for ft, sign in type_flags:
-            if ft(sources_mode):
-                file_type = sign
-                break
-        file_perms = ""
-        for (flag, do_true, do_false) in perm_flags:
-            file_perms += do_true if (sources_mode & flag) else do_false
-
-        file_size = sources_size
-
-        symlink_dest = None
-        if file_type == "l":
-            symlink_dest = os.readlink(sources_path)
-
-        return vars(LongFMT(file_type, file_perms, file_size, symlink_dest))
-
-    @staticmethod
-    def get_path_links(endpoint, path_to):
-        """
-        returns the path hierarchy with urls, to use with 'You are here:'
-        [(name, url(name)), (...), ...]
-        """
-        path_dict = path_to.split('/')
-        pathl = []
-
-        # we import flask here, in order to permit the use of this module
-        # without requiring the user to have flask (e.g. bin/debsources-update
-        # can run in another machine without flask, because it doesn't use
-        # this method)
-        from flask import url_for
-
-        for (i, p) in enumerate(path_dict):
-            pathl.append((p, url_for(endpoint,
-                                     path_to='/'.join(path_dict[:i+1]))))
-        return pathl
-
-
-class Directory(object):
-    """ a folder in a package """
-
-    def __init__(self, location, toplevel=False):
-        # if the directory is a toplevel one, we remove the .pc folder
-        self.sources_path = location.sources_path
-        self.toplevel = toplevel
-        self.location = location
-
-    def get_listing(self):
-        """
-        returns the list of folders/files in a directory,
-        along with their type (directory/file)
-        in a tuple (name, type)
-        """
-        def get_type(f):
-            if os.path.isdir(os.path.join(self.sources_path, f)):
-                return "directory"
-            else:
-                return "file"
-        get_stat, join_path = self.location.get_stat, os.path.join
-        listing = sorted(dict(name=f, type=get_type(f),
-                              stat=get_stat(join_path(self.sources_path, f)))
-                         for f in os.listdir(self.sources_path))
-        if self.toplevel:
-            listing = filter(lambda x: x['name'] != ".pc", listing)
-
-        return listing
-
-
-class SourceFile(object):
-    """ a source file in a package """
-
-    def __init__(self, location):
-        self.location = location
-        self.sources_path = location.sources_path
-        self.sources_path_static = location.sources_path_static
-        self.mime = self._find_mime()
-
-    def _find_mime(self):
-        """ returns the mime encoding and type of a file """
-        mime = magic.open(magic.MIME_TYPE)
-        mime.load()
-        type_ = mime.file(self.sources_path)
-        mime.close()
-        mime = magic.open(magic.MIME_ENCODING)
-        mime.load()
-        encoding = mime.file(self.sources_path)
-        mime.close()
-        return dict(encoding=encoding, type=type_)
-
-    def get_mime(self):
-        return self.mime
-
-    def get_sha256sum(self, session):
-        """
-        Queries the DB and returns the shasum of the file.
-        """
-        shasum = session.query(Checksum.sha256) \
-                        .filter(Checksum.package_id == Package.id) \
-                        .filter(Package.name_id == PackageName.id) \
-                        .filter(File.id == Checksum.file_id) \
-                        .filter(PackageName.name == self.location.package) \
-                        .filter(Package.version == self.location.version) \
-                        .filter(File.path == str(self.location.path)) \
-                        .first()
-        # WARNING: in the DB path is binary, and here
-        # location.path is unicode, because the path comes from
-        # the URL. TODO: check with non-unicode paths
-        if shasum:
-            shasum = shasum[0]
-        return shasum
-
-    def istextfile(self):
-        """True if self is a text file, False if it's not.
-
-        """
-        return filetype.is_text_file(self.mime['type'])
-        # for substring in text_file_mimes:
-        #     if substring in self.mime['type']:
-        #         return True
-        # return False
-
-    def get_raw_url(self):
-        """ return the raw url on disk (e.g. data/main/a/azerty/foo.bar) """
-        return self.sources_path_static
diff --git a/debsources/query.py b/debsources/query.py
new file mode 100644
index 0000000..b8ee557
--- /dev/null
+++ b/debsources/query.py
@@ -0,0 +1,387 @@
+import os
+import stat
+from collections import namedtuple
+
+from debian.debian_support import version_compare
+
+from debsources import filetype
+
+from debsources.consts import AREAS, PREFIXES_DEFAULT
+from debsources.consts import SUITES
+
+from debsources.debmirror import SourcePackage
+
+from debsources.excepts import FileOrFolderNotFound, \
+    InvalidPackageOrVersionError
+
+from debsources.models import (
+    Checksum, Ctag, File, Package, PackageName, Suite)
+
+import magic
+
+from sqlalchemy import and_
+from sqlalchemy import func as sql_func
+
+
+LongFMT = namedtuple("LongFMT", ["type", "perms", "size", "symlink_dest"])
+
+
+class Location(object):
+    """ a location in a package, can be a directory or a file """
+
+    def _get_debian_path(self, session, package, version, sources_dir):
+        """
+        Returns the Debian path of a package version.
+        For example: main/h
+                     contrib/libz
+        It's the path of a *version*, since a package can have multiple
+        versions in multiple areas (ie main/contrib/nonfree).
+
+        sources_dir: the sources directory, usually comes from the app config
+        """
+        prefix = SourcePackage.pkg_prefix(package)
+
+        try:
+            p_id = session.query(PackageName) \
+                          .filter(PackageName.name == package).first().id
+            varea = session.query(Package) \
+                           .filter(and_(Package.name_id == p_id,
+                                        Package.version == version)) \
+                           .first().area
+        except:
+            # the package or version doesn't exist in the database
+            # BUT: packages are stored for a longer time in the filesystem
+            # to allow codesearch.d.n and others less up-to-date platforms
+            # to point here.
+            # Problem: we don't know the area of such a package
+            # so we try in main, contrib and non-free.
+            for area in AREAS:
+                if os.path.exists(os.path.join(sources_dir, area,
+                                               prefix, package, version)):
+                    return os.path.join(area, prefix)
+
+            raise InvalidPackageOrVersionError("%s %s" % (package, version))
+
+        return os.path.join(varea, prefix)
+
+    def __init__(self, session, sources_dir, sources_static,
+                 package, version="", path=""):
+        """ initialises useful attributes """
+        debian_path = self._get_debian_path(session,
+                                            package, version, sources_dir)
+        self.package = package
+        self.version = version
+        self.path = path
+        self.path_to = os.path.join(package, version, path)
+
+        self.sources_path = os.path.join(
+            sources_dir,
+            debian_path,
+            self.path_to)
+
+        self.version_path = os.path.join(
+            sources_dir,
+            debian_path,
+            package,
+            version)
+
+        if not(os.path.exists(self.sources_path)):
+            raise FileOrFolderNotFound("%s" % (self.path_to))
+
+        self.sources_path_static = os.path.join(
+            sources_static,
+            debian_path,
+            self.path_to)
+
+    def is_dir(self):
+        """ True if self is a directory, False if it's not """
+        return os.path.isdir(self.sources_path)
+
+    def is_file(self):
+        """ True if self is a file, False if it's not """
+        return os.path.isfile(self.sources_path)
+
+    def is_symlink(self):
+        """ True if a folder/file is a symbolic link file, False if it's not
+        """
+        return os.path.islink(self.sources_path)
+
+    def get_package(self):
+        return self.package
+
+    def get_version(self):
+        return self.version
+
+    def get_path(self):
+        return self.path
+
+    def get_deepest_element(self):
+        if self.version == "":
+            return self.package
+        elif self.path == "":
+            return self.version
+        else:
+            return self.path.split("/")[-1]
+
+    def get_path_to(self):
+        return self.path_to.rstrip("/")
+
+    @staticmethod
+    def get_stat(sources_path):
+        """
+        Returns the filetype and permissions of the folder/file
+        on the disk, unix-styled.
+        """
+        # When porting to Python3, use stat.filemode directly
+        sources_stat = os.lstat(sources_path)
+        sources_mode, sources_size = sources_stat.st_mode, sources_stat.st_size
+        perm_flags = [
+            (stat.S_IRUSR, "r", "-"),
+            (stat.S_IWUSR, "w", "-"),
+            (stat.S_IXUSR, "x", "-"),
+            (stat.S_IRGRP, "r", "-"),
+            (stat.S_IWGRP, "w", "-"),
+            (stat.S_IXGRP, "x", "-"),
+            (stat.S_IROTH, "r", "-"),
+            (stat.S_IWOTH, "w", "-"),
+            (stat.S_IXOTH, "x", "-"),
+            ]
+        # XXX these flags should be enough.
+        type_flags = [
+            (stat.S_ISLNK, "l"),
+            (stat.S_ISREG, "-"),
+            (stat.S_ISDIR, "d"),
+            ]
+        # add the file type: d/l/-
+        file_type = " "
+        for ft, sign in type_flags:
+            if ft(sources_mode):
+                file_type = sign
+                break
+        file_perms = ""
+        for (flag, do_true, do_false) in perm_flags:
+            file_perms += do_true if (sources_mode & flag) else do_false
+
+        file_size = sources_size
+
+        symlink_dest = None
+        if file_type == "l":
+            symlink_dest = os.readlink(sources_path)
+
+        return vars(LongFMT(file_type, file_perms, file_size, symlink_dest))
+
+    @staticmethod
+    def get_path_links(endpoint, path_to):
+        """
+        returns the path hierarchy with urls, to use with 'You are here:'
+        [(name, url(name)), (...), ...]
+        """
+        path_dict = path_to.split('/')
+        pathl = []
+
+        # we import flask here, in order to permit the use of this module
+        # without requiring the user to have flask (e.g. bin/debsources-update
+        # can run in another machine without flask, because it doesn't use
+        # this method)
+        from flask import url_for
+
+        for (i, p) in enumerate(path_dict):
+            pathl.append((p, url_for(endpoint,
+                                     path_to='/'.join(path_dict[:i+1]))))
+        return pathl
+
+
+class Directory(object):
+    """ a folder in a package """
+
+    def __init__(self, location, toplevel=False):
+        # if the directory is a toplevel one, we remove the .pc folder
+        self.sources_path = location.sources_path
+        self.toplevel = toplevel
+        self.location = location
+
+    def get_listing(self):
+        """
+        returns the list of folders/files in a directory,
+        along with their type (directory/file)
+        in a tuple (name, type)
+        """
+        def get_type(f):
+            if os.path.isdir(os.path.join(self.sources_path, f)):
+                return "directory"
+            else:
+                return "file"
+        get_stat, join_path = self.location.get_stat, os.path.join
+        listing = sorted(dict(name=f, type=get_type(f),
+                              stat=get_stat(join_path(self.sources_path, f)))
+                         for f in os.listdir(self.sources_path))
+        if self.toplevel:
+            listing = filter(lambda x: x['name'] != ".pc", listing)
+
+        return listing
+
+
+class SourceFile(object):
+    """ a source file in a package """
+
+    def __init__(self, location):
+        self.location = location
+        self.sources_path = location.sources_path
+        self.sources_path_static = location.sources_path_static
+        self.mime = self._find_mime()
+
+    def _find_mime(self):
+        """ returns the mime encoding and type of a file """
+        mime = magic.open(magic.MIME_TYPE)
+        mime.load()
+        type_ = mime.file(self.sources_path)
+        mime.close()
+        mime = magic.open(magic.MIME_ENCODING)
+        mime.load()
+        encoding = mime.file(self.sources_path)
+        mime.close()
+        return dict(encoding=encoding, type=type_)
+
+    def get_mime(self):
+        return self.mime
+
+    def get_sha256sum(self, session):
+        """
+        Queries the DB and returns the shasum of the file.
+        """
+        shasum = session.query(Checksum.sha256) \
+                        .filter(Checksum.package_id == Package.id) \
+                        .filter(Package.name_id == PackageName.id) \
+                        .filter(File.id == Checksum.file_id) \
+                        .filter(PackageName.name == self.location.package) \
+                        .filter(Package.version == self.location.version) \
+                        .filter(File.path == str(self.location.path)) \
+                        .first()
+        # WARNING: in the DB path is binary, and here
+        # location.path is unicode, because the path comes from
+        # the URL. TODO: check with non-unicode paths
+        if shasum:
+            shasum = shasum[0]
+        return shasum
+
+    def istextfile(self):
+        """True if self is a text file, False if it's not.
+
+        """
+        return filetype.is_text_file(self.mime['type'])
+        # for substring in text_file_mimes:
+        #     if substring in self.mime['type']:
+        #         return True
+        # return False
+
+    def get_raw_url(self):
+        """ return the raw url on disk (e.g. data/main/a/azerty/foo.bar) """
+        return self.sources_path_static
+
+
+class Queries(object):
+
+    @staticmethod
+    def get_packages_prefixes(cache_dir):
+        """
+        returns the packages prefixes (a, b, ..., liba, libb, ..., y, z)
+        cache_dir: the cache directory, usually comes from the app config
+        """
+        try:
+            with open(os.path.join(cache_dir, 'pkg-prefixes')) as f:
+                prefixes = [l.rstrip() for l in f]
+        except IOError:
+            prefixes = PREFIXES_DEFAULT
+        return prefixes
+
+    @staticmethod
+    def list_versions(session, packagename, suite=""):
+        """
+        return all versions of a packagename. if suite is specified, only
+        versions contained in that suite are returned.
+        """
+        try:
+            name_id = session.query(PackageName) \
+                             .filter(PackageName.name == packagename) \
+                             .first().id
+        except Exception:
+            raise InvalidPackageOrVersionError(packagename)
+        try:
+            if not suite:
+                versions = session.query(Package) \
+                                  .filter(Package.name_id == name_id).all()
+            else:
+                versions = (session.query(Package)
+                                   .filter(Package.name_id == name_id)
+                                   .filter(sql_func.lower(Suite.suite)
+                                           == suite)
+                                   .filter(Suite.package_id == Package.id)
+                                   .all())
+        except Exception:
+            raise InvalidPackageOrVersionError(packagename)
+        # we sort the versions according to debian versions rules
+        versions = sorted(versions, cmp=version_compare)
+        return versions
+
+    @staticmethod
+    def list_versions_w_suites(session, packagename, suite=""):
+        """
+        return versions with suites. if suite is provided, then only return
+        versions contained in that suite.
+        """
+        # FIXME a left outer join on (Package, Suite) is more preferred.
+        # However, per https://stackoverflow.com/a/997467, custom aggregation
+        # function to concatenate the suite names for the group_by should be
+        # defined on database connection level.
+        versions = Queries.list_versions(session, packagename, suite)
+        versions_w_suites = []
+        try:
+            for v in versions:
+                suites = session.query(Suite) \
+                                .filter(Suite.package_id == v.id) \
+                                .all()
+                # sort the suites according to debsources.consts.SUITES
+                # use keyfunc to make it py3 compatible
+                suites.sort(key=lambda s: SUITES['all'].index(s.suite))
+                suites = [s.suite for s in suites]
+                v = v.to_dict()
+                v['suites'] = suites
+                versions_w_suites.append(v)
+        except Exception:
+            raise InvalidPackageOrVersionError(packagename)
+        return versions_w_suites
+
+    @staticmethod
+    def find_ctag(session, ctag, package=None, slice_=None):
+        """
+        Returns places in the code where a ctag is found.
+             tuple (count, [sliced] results)
+
+        session: an SQLAlchemy session
+        ctag: the ctag to search
+        package: limit results to package
+        """
+
+        results = (session.query(PackageName.name.label("package"),
+                                 Package.version.label("version"),
+                                 Ctag.file_id.label("file_id"),
+                                 File.path.label("path"),
+                                 Ctag.line.label("line"))
+                   .filter(Ctag.tag == ctag)
+                   .filter(Ctag.package_id == Package.id)
+                   .filter(Ctag.file_id == File.id)
+                   .filter(Package.name_id == PackageName.id)
+                   )
+        if package is not None:
+            results = results.filter(PackageName.name == package)
+
+        results = results.order_by(Ctag.package_id, File.path)
+        count = results.count()
+        if slice_ is not None:
+            results = results.slice(slice_[0], slice_[1])
+        results = [dict(package=res.package,
+                        version=res.version,
+                        path=res.path,
+                        line=res.line)
+                   for res in results.all()]
+        return (count, results)
diff --git a/debsources/tests/test_queries.py b/debsources/tests/test_queries.py
new file mode 100644
index 0000000..0fb7094
--- /dev/null
+++ b/debsources/tests/test_queries.py
@@ -0,0 +1,69 @@
+import unittest
+
+from nose.plugins.attrib import attr
+
+from debsources.query import Queries
+
+from debsources.tests.db_testing import DbTestFixture
+from debsources.tests.testdata import TEST_DB_NAME
+
+
+@attr('Queries')
+class QueriesTest(unittest.TestCase, DbTestFixture):
+
+    @classmethod
+    def setUpClass(cls):
+        cls.db_setup_cls()
+
+        # creates an app object, which is used to run queries
+        from debsources.app import app_wrapper
+
+        # overrides a few configuration parameters needed for testing:
+        uri = "postgresql:///" + TEST_DB_NAME
+        app_wrapper.app.config["SQLALCHEMY_DATABASE_URI"] = uri
+        app_wrapper.app.config['LIST_OFFSET'] = 5
+        app_wrapper.app.testing = True
+
+        app_wrapper.go()
+
+        cls.app = app_wrapper.app.test_client()
+        cls.app_wrapper = app_wrapper
+
+    @classmethod
+    def tearDownClass(cls):
+        cls.app_wrapper.engine.dispose()
+        cls.db_teardown_cls()
+
+    def test_packages_prefixes(self):
+        self.assertEqual(Queries.get_packages_prefixes(
+            self.app_wrapper.app.config["CACHE_DIR"]),
+            ['b', 'd', 'f', 'g', 'l', 'libc', 'm', 'n', 'o', 's', 'u'])
+
+    def test_list_versions(self):
+        # Test without suite
+        packages = Queries.list_versions(self.session, "gnubg")
+        self.assertEqual([p.version for p in packages],
+                         ["0.90+20091206-4", "0.90+20120429-1", "1.02.000-2"])
+
+        # Test with suite
+        packages = Queries.list_versions(self.session, "gnubg", "wheezy")
+        self.assertEqual([p.version for p in packages], ["0.90+20120429-1"])
+
+        # Test returning suites without suite as parameter
+        self.assertTrue({'suites': [u'wheezy'], 'version': u'0.90+20120429-1',
+                         'area': u'main'} in
+                        Queries.list_versions_w_suites(self.session, "gnubg"))
+
+        # Test returning suites with a suite as parameter
+        self.assertEqual(Queries.list_versions_w_suites(self.session, "gnubg", "jessie"),
+                        [{'suites': [u'jessie', u'sid'],
+                        'version': u'1.02.000-2', 'area': u'main'}])
+
+    def test_find_ctag(self):
+        self.assertEqual(Queries.find_ctag(self.session, "swap")[0], 8)
+
+        ctags = Queries.find_ctag(self.session, "swap", "gnubg")
+        self.assertEqual(ctags[0], 5)
+        self.assertTrue({'path': 'eval.c', 'line': 1747,
+                        'version': u'0.90+20091206-4', 'package': u'gnubg'}
+                        in ctags[1])
-- 
2.1.4

Attachment: signature.asc
Description: OpenPGP digital signature


Reply to: