Source code for invenio_db.ext

# -*- coding: utf-8 -*-
# This file is part of Invenio.
# Copyright (C) 2015-2018 CERN.
# Invenio is free software; you can redistribute it and/or modify it
# under the terms of the MIT License; see LICENSE file for more details.

"""Database management for Invenio."""

from __future__ import absolute_import, print_function

import os

import pkg_resources
import sqlalchemy as sa
from flask_alembic import Alembic
from sqlalchemy_utils.functions import get_class_by_table

from .cli import db as db_cmd
from .shared import db
from .utils import versioning_models_registered

[docs]class InvenioDB(object): """Invenio database extension.""" def __init__(self, app=None, **kwargs): """Extension initialization.""" self.alembic = Alembic(run_mkdir=False, command_name='alembic') if app: self.init_app(app, **kwargs)
[docs] def init_app(self, app, **kwargs): """Initialize application object.""" self.init_db(app, **kwargs) app.config.setdefault('ALEMBIC', { 'script_location': pkg_resources.resource_filename( 'invenio_db', 'alembic' ), 'version_locations': [ (, pkg_resources.resource_filename( base_entry.module_name, os.path.join(*base_entry.attrs) )) for base_entry in pkg_resources.iter_entry_points( 'invenio_db.alembic' ) ], }) self.alembic.init_app(app) app.extensions['invenio-db'] = self app.cli.add_command(db_cmd)
[docs] def init_db(self, app, entry_point_group='invenio_db.models', **kwargs): """Initialize Flask-SQLAlchemy extension.""" # Setup SQLAlchemy app.config.setdefault( 'SQLALCHEMY_DATABASE_URI', 'sqlite:///' + os.path.join(app.instance_path, + '.db') ) app.config.setdefault('SQLALCHEMY_ECHO', False) # Initialize Flask-SQLAlchemy extension. database = kwargs.get('db', db) database.init_app(app) # Initialize versioning support. self.init_versioning(app, database, kwargs.get('versioning_manager')) # Initialize model bases if entry_point_group: for base_entry in pkg_resources.iter_entry_points( entry_point_group): base_entry.load() # All models should be loaded by now. sa.orm.configure_mappers() # Ensure that versioning classes have been built. if app.config['DB_VERSIONING']: manager = self.versioning_manager if manager.pending_classes: if not versioning_models_registered(manager, database.Model): manager.builder.configure_versioned_classes() elif 'transaction' not in database.metadata.tables: manager.declarative_base = database.Model manager.create_transaction_model() manager.plugins.after_build_tx_class(manager)
[docs] def init_versioning(self, app, database, versioning_manager=None): """Initialize the versioning support using SQLAlchemy-Continuum.""" try: pkg_resources.get_distribution('sqlalchemy_continuum') except pkg_resources.DistributionNotFound: # pragma: no cover default_versioning = False else: default_versioning = True app.config.setdefault('DB_VERSIONING', default_versioning) if not app.config['DB_VERSIONING']: return if not default_versioning: # pragma: no cover raise RuntimeError( 'Please install extra versioning support first by running ' 'pip install invenio-db[versioning].' ) # Now we can import SQLAlchemy-Continuum. from sqlalchemy_continuum import make_versioned from sqlalchemy_continuum import versioning_manager as default_vm from sqlalchemy_continuum.plugins import FlaskPlugin # Try to guess user model class: if 'DB_VERSIONING_USER_MODEL' not in app.config: # pragma: no cover try: pkg_resources.get_distribution('invenio_accounts') except pkg_resources.DistributionNotFound: user_cls = None else: user_cls = 'User' else: user_cls = app.config.get('DB_VERSIONING_USER_MODEL') plugins = [FlaskPlugin()] if user_cls else [] # Call make_versioned() before your models are defined. self.versioning_manager = versioning_manager or default_vm make_versioned( user_cls=user_cls, manager=self.versioning_manager, plugins=plugins, ) # Register models that have been loaded beforehand. builder = self.versioning_manager.builder for tbl in database.metadata.tables.values(): builder.instrument_versioned_classes( database.mapper, get_class_by_table(database.Model, tbl) )