# Source code for weaver.database.mongodb

# MongoDB
import decimal
import logging
import uuid
import warnings
from typing import TYPE_CHECKING, overload

import bson
import pymongo
import pymongo.errors
from bson.codec_options import TypeCodec, TypeRegistry

from weaver.database.base import DatabaseInterface
from weaver.store.mongodb import (
    MongodbBillStore,
    MongodbJobStore,
    MongodbProcessStore,
    MongodbQuoteStore,
    MongodbServiceStore,
    MongodbVaultStore,
)
from weaver.utils import get_settings, is_uuid

if TYPE_CHECKING:
    from typing import Any, Optional, Type, Union

    from pymongo.database import Database

    from weaver.database.base import (
        StoreBillsSelector,
        StoreJobsSelector,
        StoreProcessesSelector,
        StoreQuotesSelector,
        StoreSelector,
        StoreServicesSelector,
        StoreVaultSelector,
    )
    from weaver.typedefs import AnySettingsContainer, JSON

LOGGER = logging.getLogger(__name__)

# pylint: disable=C0103,invalid-name
# NOTE(review): not referenced by the code in this module; presumably kept as a public global — confirm before removing.
MongoDB = None  # type: Optional[Database]

# All store implementations that 'MongoDatabase.get_store' can instantiate.
# A store is selected when its 'type' attribute matches the requested store type.
MongodbStores = frozenset([
    MongodbServiceStore,
    MongodbProcessStore,
    MongodbJobStore,
    MongodbQuoteStore,
    MongodbBillStore,
    MongodbVaultStore,
])
if TYPE_CHECKING:
    # Type aliases only; they are never defined at runtime.
    # pylint: disable=E0601,used-before-assignment
    AnyMongodbStore = Union[
        MongodbServiceStore,
        MongodbProcessStore,
        MongodbJobStore,
        MongodbQuoteStore,
        MongodbBillStore,
        MongodbVaultStore,
    ]
    # Anything that can be used to designate a store: a selector, a store instance, or a store class.
    AnyMongodbStoreType = Union[
        StoreSelector,
        AnyMongodbStore,
        Type[MongodbServiceStore],
        Type[MongodbProcessStore],
        Type[MongodbJobStore],
        Type[MongodbQuoteStore],
        Type[MongodbBillStore],
        Type[MongodbVaultStore],
    ]
class MongoDatabase(DatabaseInterface):
    """
    MongoDB implementation of the database interface.

    Stores listed in ``MongodbStores`` are created lazily and cached per store type.
    """
    # data-schema revision expected by this code; 'run_migration' upgrades the stored revision up to this value
    _revision = 1
    _database = None
    _settings = None
    _stores = None
    type = "mongodb"

    def __init__(self, container):
        # type: (AnySettingsContainer) -> None
        """
        Connect to the configured database and prepare an empty store cache.

        :param container: any container from which the application settings can be retrieved.
        """
        super(MongoDatabase, self).__init__(container)
        self._database = get_mongodb_engine(container)
        self._settings = get_settings(container)
        self._stores = {}
        # 'server_info()' queries the server, so this line requires a reachable MongoDB instance
        LOGGER.debug("Database [%s] using versions: {MongoDB: %s, pymongo: %s}",
                     self._database.name, self._database.client.server_info()["version"], pymongo.__version__)

    def reset_store(self, store_type):
        # type: (StoreSelector) -> AnyMongodbStore
        """
        Drop the cached store instance so that the next :meth:`get_store` call recreates it.

        :param store_type: type of the store to reset.
        :returns: the previously cached store instance, or ``None`` if none was cached.
        """
        store_type = self._get_store_type(store_type)
        return self._stores.pop(store_type, None)

    @overload
    def get_store(self, store_type, *store_args, **store_kwargs):
        # type: (StoreBillsSelector, *Any, **Any) -> MongodbBillStore
        ...

    @overload
    def get_store(self, store_type, *store_args, **store_kwargs):
        # type: (StoreQuotesSelector, *Any, **Any) -> MongodbQuoteStore
        ...

    @overload
    def get_store(self, store_type, *store_args, **store_kwargs):
        # type: (StoreJobsSelector, *Any, **Any) -> MongodbJobStore
        ...

    @overload
    def get_store(self, store_type, *store_args, **store_kwargs):
        # type: (StoreProcessesSelector, *Any, **Any) -> MongodbProcessStore
        ...

    @overload
    def get_store(self, store_type, *store_args, **store_kwargs):
        # type: (StoreServicesSelector, *Any, **Any) -> MongodbServiceStore
        ...

    @overload
    def get_store(self, store_type, *store_args, **store_kwargs):
        # type: (StoreVaultSelector, *Any, **Any) -> MongodbVaultStore
        ...

    def get_store(self, store_type, *store_args, **store_kwargs):
        # type: (StoreSelector, *Any, **Any) -> AnyMongodbStore
        """
        Retrieve a store from the database.

        The store is instantiated on the first request for its type and cached afterwards.
        On later calls, ``store_args`` and ``store_kwargs`` are ignored and the cached instance is returned.

        :param store_type: type of the store to retrieve/create.
        :param store_args: additional arguments to pass down to the store.
        :param store_kwargs: additional keyword arguments to pass down to the store.
        :returns: the store instance matching the requested type.
        :raises NotImplementedError: if no store in ``MongodbStores`` matches the requested type.
        """
        store_type = self._get_store_type(store_type)
        for store in MongodbStores:
            if store.type != store_type:
                continue
            if store_type not in self._stores:
                if "settings" not in store_kwargs:
                    store_kwargs["settings"] = self._settings
                # each store uses the collection named after its type
                self._stores[store_type] = store(
                    *store_args,
                    collection=getattr(self.get_session(), store_type),
                    **store_kwargs,
                )
            return self._stores[store_type]
        raise NotImplementedError(f"Database '{self.type}' cannot find matching store '{store_type}'.")

    def get_session(self):
        # type: (...) -> Any
        """
        Obtain the underlying database handle used to access collections.
        """
        return self._database

    def get_information(self):
        # type: (...) -> JSON
        """
        Obtain information about the database implementation.

        The revision is read from the first document of the ``version`` collection (``0`` when there is none).

        :returns: JSON with parameters: ``{"version": "<version>", "type": "<db_type>"}``.
        """
        result = list(self._database.version.find().limit(1))
        revision = result[0]["revision"] if result else 0
        return {"version": revision, "type": self.type}

    def is_ready(self):
        # type: (...) -> bool
        """
        Indicate whether both the database handle and the settings were initialized.
        """
        return self._database is not None and self._settings is not None

    def run_migration(self):
        # type: (...) -> None
        """
        Runs any necessary data-schema migration steps.

        Each revision step from the stored revision up to :attr:`_revision` is applied in order.
        After each step, the revision stored in the ``version`` collection is bumped.

        :raises AssertionError: if the stored revision is newer than the one supported by this code.
        """
        db_info = self.get_information()
        LOGGER.info("Running database migration as needed for %s", db_info)
        version = db_info["version"]
        # explicit raise (rather than 'assert') so the guard is not stripped when running with 'python -O'
        if self._revision < version:
            raise AssertionError("Cannot process future DB revision.")
        for rev in range(version, self._revision):
            from_to_msg = f"[Migrating revision: {rev} -> {rev + 1}]"

            if rev == 0:
                LOGGER.info("%s Convert objects with string for UUID-like fields to real UUID types.", from_to_msg)
                self._migrate_string_uuid_fields()

            # NOTE: add any needed migration revisions here with (if rev = next-index)...

            # update and move to next revision (upsert creates the version document if missing)
            self._database.version.update_one({"revision": rev}, {"$set": {"revision": rev + 1}}, upsert=True)
            # record the revision the database now holds, so the final log reports the actual state
            db_info["version"] = rev + 1
        LOGGER.info("Database up-to-date with: %s", db_info)

    def _migrate_string_uuid_fields(self):
        # type: () -> None
        """
        Revision 0 -> 1: convert string UUID-like fields of jobs, bills and quotes to :class:`uuid.UUID` values.
        """
        collection = self._database.jobs
        for cur in collection.find({"id": {"$type": "string"}}):
            collection.update_one(
                {"_id": cur["_id"]},
                {"$set": {
                    "id": uuid.UUID(str(cur["id"])),
                    # non-UUID 'task_id' values are kept unchanged, whereas non-UUID 'wps_id' values are cleared
                    "task_id": uuid.UUID(str(cur["task_id"])) if is_uuid(cur["task_id"]) else cur["task_id"],
                    "wps_id": uuid.UUID(str(cur["wps_id"])) if is_uuid(cur["wps_id"]) else None
                }}
            )
        for collection in [self._database.bills, self._database.quotes]:
            for cur in collection.find({"id": {"$type": "string"}}):
                collection.update_one({"_id": cur["_id"]}, {"$set": {"id": uuid.UUID(str(cur["id"]))}})
class DecimalCodec(TypeCodec):
    """
    Converter that will automatically perform necessary encoding/decoding of decimal types for `MongoDB`.
    """
    python_type = decimal.Decimal
    bson_type = bson.Decimal128

    def transform_python(self, value):
        # type: (decimal.Decimal) -> bson.Decimal128
        """
        Encode a Python :class:`decimal.Decimal` into its BSON ``Decimal128`` representation.
        """
        return DecimalCodec.bson_type(value)

    def transform_bson(self, value):
        # type: (bson.Decimal128) -> decimal.Decimal
        """
        Decode a BSON ``Decimal128`` value back into a Python :class:`decimal.Decimal`.
        """
        return value.to_decimal()
def get_mongodb_connection(container):
    # type: (AnySettingsContainer) -> Database
    """
    Obtains the basic database connection from settings.

    Defaults are used, with a warning, for any missing ``mongodb.host``, ``mongodb.port`` and ``mongodb.db_name``
    setting. The defaults are also stored into the retrieved settings mapping.

    :param container: any container from which the application settings can be retrieved.
    :returns: database handle for the configured ``mongodb.db_name``.
    """
    settings = get_settings(container)
    settings_default = [("mongodb.host", "localhost"), ("mongodb.port", 27017), ("mongodb.db_name", "weaver")]
    for setting, default in settings_default:
        if settings.get(setting, None) is None:
            warnings.warn(f"Setting '{setting}' not defined in registry, using default [{default}].")
            settings[setting] = default
    client = pymongo.MongoClient(
        settings["mongodb.host"],
        int(settings["mongodb.port"]),
        # defer the actual connection until the first operation
        connect=False,
        # Must specify representation since PyMongo 4.0 and also to avoid Python 3.6 error.
        # Without it, native 'uuid.UUID' values (such as those written by the database migration) cannot be encoded.
        uuidRepresentation="pythonLegacy",
        # Require that datetime objects be returned with timezone awareness.
        # This ensures that missing 'tzinfo' does not get misinterpreted as locale time when
        # loading objects from DB, since by default 'datetime.datetime' employs 'tzinfo=None'
        # for locale naive datetime objects, while MongoDB stores Date in ISO-8601 format.
        tz_aware=True,
        type_registry=TypeRegistry([DecimalCodec()]),
    )
    return client[settings["mongodb.db_name"]]
def get_mongodb_engine(container):
    # type: (AnySettingsContainer) -> Database
    """
    Obtains the database with configuration ready for usage.

    Ensures that the unique indexes on the service, process, job, quote and bill collections exist.

    :param container: any container from which the application settings can be retrieved.
    :returns: database handle with its unique indexes created.
    """
    db = get_mongodb_connection(container)
    db.services.create_index("name", unique=True)
    db.services.create_index("url", unique=True)
    db.processes.create_index("identifier", unique=True)
    db.jobs.create_index("id", unique=True)
    db.quotes.create_index("id", unique=True)
    db.bills.create_index("id", unique=True)
    return db