Skip to content
Snippets Groups Projects
model.py 5.56 KiB
Newer Older
  • Learn to ignore specific revisions
  • # Copyright 2019 PrivateStorage.io, LLC
    #
    # Licensed under the Apache License, Version 2.0 (the "License");
    # you may not use this file except in compliance with the License.
    # You may obtain a copy of the License at
    #
    #     http://www.apache.org/licenses/LICENSE-2.0
    #
    # Unless required by applicable law or agreed to in writing, software
    # distributed under the License is distributed on an "AS IS" BASIS,
    # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    # See the License for the specific language governing permissions and
    # limitations under the License.
    
    """
    This module implements models (in the MVC sense) for the client side of
    the storage plugin.
    """
    
    
    from functools import (
        wraps,
    )
    from json import (
        loads,
        dumps,
    )

    from sqlite3 import (
        connect as _connect,
    )

    import attr

    from twisted.python.filepath import (
        FilePath,
    )
    
    class StoreAddError(Exception):
        """
        An error adding an item to a store.

        NOTE(review): nothing in this module raises it; presumably callers of
        the store raise it on a failed add -- confirm.

        :ivar reason: Whatever value the raiser supplied to explain the
            failure.
        """

        def __init__(self, reason):
            # Keep the cause available as a named attribute for handlers.
            self.reason = reason
    
    
    class StoreDirectoryError(Exception):
        """
        The directory which should contain a store could not be created.

        ``open_and_initialize`` raises this when creating the database file's
        parent directory fails with ``OSError``.

        :ivar reason: The ``OSError`` raised while creating the directory.
        """

        def __init__(self, reason):
            # Keep the original OSError so handlers can inspect errno etc.
            self.reason = reason
    
    
    
    class SchemaError(TypeError):
        """
        A database's schema version is not the version required by the caller
        (see ``open_and_initialize``).
        """
    
    
    # File name of the SQLite3 database backing ``PaymentReferenceStore``; it is
    # resolved with ``node_config.get_private_path(...)``.  The ``-v1`` suffix
    # is part of the name itself, so changing it points at a different file.
    CONFIG_DB_NAME = u"privatestorageio-satauthz-v1.sqlite3"
    
    
    def open_and_initialize(path, required_schema_version, connect=None):
        """
        Open a SQLite3 database for use as a payment reference store.

        Create the database and populate it with a schema, if it does not already
        exist.

        :param FilePath path: The location of the SQLite3 database file.  Its
            parent directory is created if it does not already exist.

        :param int required_schema_version: The schema version which must be
            present in the database in order for a SQLite3 connection to be
            returned.

        :param connect: An optional callable used to open the database.  It is
            called with the database path and an ``isolation_level`` keyword
            argument.  If ``None``, the module-level ``_connect`` is used.
            This allows an alternative such as ``memory_connect`` to be
            substituted.

        :raise StoreDirectoryError: If the parent directory of ``path`` cannot
            be created.

        :raise SchemaError: If the schema in the database does not match the
            required schema version.  The connection is closed before this
            propagates.

        :return: A SQLite3 connection object for the database at the given path.
        """
        if connect is None:
            connect = _connect

        try:
            path.parent().makedirs(ignoreExistingDirectory=True)
        except OSError as e:
            raise StoreDirectoryError(e)

        conn = connect(
            path.asBytesMode().path,
            isolation_level="IMMEDIATE",
        )
        try:
            # For a sqlite3 connection, ``with conn`` commits on success and
            # rolls back if anything inside raises.
            with conn:
                cursor = conn.cursor()
                # This code knows how to create schema version 1.  This is
                # regardless of what the caller *wants* to find in the database.
                cursor.execute(
                    """
                    CREATE TABLE IF NOT EXISTS [version] AS SELECT 1 AS [version]
                    """
                )
                cursor.execute(
                    """
                    SELECT [version] FROM [version]
                    """
                )

                [(actual_version,)] = cursor.fetchall()
                if actual_version != required_schema_version:
                    raise SchemaError(
                        "Unexpected database schema version.  Required {}.  Got {}.".format(
                            required_schema_version,
                            actual_version,
                        ),
                    )

                cursor.execute(
                    """
                    CREATE TABLE IF NOT EXISTS [payment-references] (
                        [number] text,

                        PRIMARY KEY([number])
                    )
                    """,
                )
        except Exception:
            # The caller never receives the connection on failure, so it
            # cannot close it; do that here rather than leak it.
            conn.close()
            raise
        return conn
    
    
    def with_cursor(f):
        """
        Decorate a method so it runs inside ``with self._connection:`` and
        receives a new cursor from that connection as its first argument after
        ``self``.

        For a sqlite3 connection this makes each decorated call one
        transaction: committed if the method returns, rolled back if it raises.

        :param f: A method taking ``(self, cursor, *args, **kwargs)``.

        :return: A method taking ``(self, *args, **kwargs)`` which returns
            whatever ``f`` returns.
        """
        @wraps(f)
        def _run_in_transaction(self, *args, **kwargs):
            connection = self._connection
            with connection:
                cursor = connection.cursor()
                return f(self, cursor, *args, **kwargs)

        return _run_in_transaction
    
    
    
    def memory_connect(path, *args, **kwargs):
        """
        Open a fresh in-memory SQLite3 database no matter which ``path`` is
        requested.

        The signature mirrors the default ``connect`` used by
        ``open_and_initialize``, so this can be passed as its ``connect``
        argument.  Any other positional and keyword arguments are forwarded
        unchanged.
        """
        # ``path`` is accepted only for signature compatibility and is ignored.
        in_memory = ":memory:"
        return _connect(in_memory, *args, **kwargs)
    
    
    
    @attr.s(frozen=True)
    class PaymentReferenceStore(object):
        """
        This class implements persistence for payment references.

        :ivar FilePath database_path: The location of the SQLite3 database file
            in which payment references are persisted.

        :ivar _connection: A SQLite3 connection to the database at
            ``database_path``, as returned by ``open_and_initialize``.
        """

        database_path = attr.ib(type=FilePath)
        _connection = attr.ib()

        @classmethod
        def from_node_config(cls, node_config, connect=None):
            """
            Open (creating if necessary) the payment reference store for a node.

            :param allmydata.node._Config node_config: The Tahoe-LAFS node
                configuration object for the node that owns the persisted
                payment references.  The database is stored at
                ``node_config.get_private_path(CONFIG_DB_NAME)``.

            :param connect: An optional callable used to open the database;
                passed through to ``open_and_initialize``.

            :raise StoreDirectoryError: If the database's directory cannot be
                created.

            :raise SchemaError: If the database does not have schema version 1.

            :return PaymentReferenceStore: A store backed by that database.
            """
            db_path = FilePath(node_config.get_private_path(CONFIG_DB_NAME))
            conn = open_and_initialize(
                db_path,
                required_schema_version=1,
                connect=connect,
            )
            return cls(
                db_path,
                conn,
            )

        @with_cursor
        def get(self, cursor, prn):
            """
            Look up a previously added payment reference.

            :param cursor: A cursor supplied by ``with_cursor``.

            :param prn: The payment reference number to find.

            :raise KeyError: If no payment reference with that number has been
                added.

            :return PaymentReference: The matching payment reference.
            """
            cursor.execute(
                """
                SELECT
                    ([number])
                FROM
                    [payment-references]
                WHERE
                    [number] = ?
                """,
                (prn,),
            )
            refs = cursor.fetchall()
            if not refs:
                raise KeyError(prn)
            return PaymentReference(refs[0][0])

        @with_cursor
        def add(self, cursor, prn):
            """
            Persist a payment reference number.

            Adding a number which is already present is a no-op because of
            ``INSERT OR IGNORE``.

            :param cursor: A cursor supplied by ``with_cursor``.

            :param prn: The payment reference number to add.
            """
            cursor.execute(
                """
                INSERT OR IGNORE INTO [payment-references] VALUES (?)
                """,
                (prn,)
            )

        @with_cursor
        def list(self, cursor):
            """
            Get every persisted payment reference.

            :param cursor: A cursor supplied by ``with_cursor``.

            :return: A ``list`` of ``PaymentReference``, one per stored number.
                The query has no ``ORDER BY``, so the order is whatever SQLite
                returns.
            """
            cursor.execute(
                """
                SELECT ([number]) FROM [payment-references]
                """,
            )
            refs = cursor.fetchall()
            return [
                PaymentReference(number)
                for (number,)
                in refs
            ]

    
    @attr.s
    class PaymentReference(object):
        """
        A payment reference, identified by its number.

        Instances round-trip through JSON objects which carry a ``u"version"``
        key naming the serialization format; loading dispatches on that key to
        a ``from_json_v<version>`` classmethod.

        :ivar number: The payment reference number.
        """
        number = attr.ib()

        @classmethod
        def from_json(cls, json):
            """
            Load a payment reference from its JSON serialization.

            :param json: A JSON document, such as one produced by ``to_json``.

            :return PaymentReference: The deserialized instance.
            """
            decoded = loads(json)
            version = decoded.pop(u"version")
            # Pick the loader for this serialization version, e.g. from_json_v1.
            loader_name = "from_json_v{}".format(version)
            loader = getattr(cls, loader_name)
            return loader(decoded)

        @classmethod
        def from_json_v1(cls, values):
            """
            Build a payment reference from its version 1 marshalled form.

            :param dict values: The marshalled attributes with the
                ``u"version"`` key already removed.
            """
            return cls(**values)

        def to_json(self):
            """
            Serialize this payment reference to a JSON string in the format
            produced by ``marshal``.
            """
            marshalled = self.marshal()
            return dumps(marshalled)

        def marshal(self):
            """
            Get a JSON-compatible ``dict`` for this payment reference in the
            current (version 1) format.
            """
            return self.to_json_v1()

        def to_json_v1(self):
            """
            Get the version 1 marshalled form: this instance's attributes plus a
            ``u"version"`` key set to ``1``.
            """
            marshalled = attr.asdict(self)
            marshalled[u"version"] = 1
            return marshalled