Skip to content
Snippets Groups Projects
model.py 5.56 KiB
Newer Older
# Copyright 2019 PrivateStorage.io, LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
This module implements models (in the MVC sense) for the client side of
the storage plugin.
"""

from functools import (
    wraps,
)
from json import (
    loads,
    dumps,
)
from sqlite3 import (
    connect as _connect,
)

import attr

from twisted.python.filepath import (
    FilePath,
)


class StoreAddError(Exception):
    """
    Error type for a failure to add something to the store.

    (Named for that purpose; nothing in this part of the module raises it.)

    :ivar reason: Whatever cause the raiser supplied.
    """

    def __init__(self, reason):
        # Expose the cause as a named attribute rather than only via ``args``.
        self.reason = reason


class StoreDirectoryError(Exception):
    """
    Raised by ``open_and_initialize`` when the directory that should contain
    the database file cannot be created.

    :ivar reason: The ``OSError`` that was encountered while creating the
        directory.
    """

    def __init__(self, reason):
        # Keep the original OSError reachable for callers that want details.
        self.reason = reason


class SchemaError(TypeError):
    """
    The schema version recorded in the database is not the one the caller
    requires.
    """


# File name (within the node's private directory) of the SQLite3 database
# that ``PaymentReferenceStore.from_node_config`` opens.
CONFIG_DB_NAME = u"privatestorageio-satauthz-v1.sqlite3"

def open_and_initialize(path, required_schema_version, connect=None):
    """
    Open a SQLite3 database for use as a payment reference store.

    Create the database and populate it with a schema, if it does not already
    exist.

    :param FilePath path: The location of the SQLite3 database file.  Its
        parent directory is created if it does not already exist.

    :param int required_schema_version: The schema version which must be
        present in the database in order for a SQLite3 connection to be
        returned.

    :param connect: An optional callable used instead of the module's default
        ``_connect`` to open the database (for example, ``memory_connect``).
        It is called with the database path and an ``isolation_level``
        keyword argument.

    :raise StoreDirectoryError: If the parent directory of ``path`` cannot be
        created.

    :raise SchemaError: If the schema in the database does not match the
        required schema version.

    :return: A SQLite3 connection object for the database at the given path.
    """
    if connect is None:
        connect = _connect
    try:
        path.parent().makedirs(ignoreExistingDirectory=True)
    except OSError as e:
        raise StoreDirectoryError(e)

    conn = connect(
        path.asBytesMode().path,
        isolation_level="IMMEDIATE",
    )
    try:
        with conn:
            cursor = conn.cursor()
            cursor.execute(
                # This code knows how to create schema version 1.  This is
                # regardless of what the caller *wants* to find in the database.
                """
                CREATE TABLE IF NOT EXISTS [version] AS SELECT 1 AS [version]
                """
            )
            cursor.execute(
                """
                SELECT [version] FROM [version]
                """
            )
            [(actual_version,)] = cursor.fetchall()
            if actual_version != required_schema_version:
                raise SchemaError(
                    "Unexpected database schema version.  Required {}.  Got {}.".format(
                        required_schema_version,
                        actual_version,
                    ),
                )

            cursor.execute(
                """
                CREATE TABLE IF NOT EXISTS [payment-references] (
                    [number] text,
                    PRIMARY KEY([number])
                )
                """,
            )
    except Exception:
        # The caller never receives the connection when initialization fails,
        # so it could not close it; do that here before propagating.
        conn.close()
        raise
    return conn


def with_cursor(f):
    """
    Decorate a method so it runs inside ``with self._connection:`` and
    receives a freshly created cursor as its first argument after ``self``.

    For a standard ``sqlite3`` connection, the ``with`` block commits when the
    method returns normally and rolls back if it raises.

    :param f: A function taking ``(self, cursor, *args, **kwargs)``.
    :return: A function taking ``(self, *args, **kwargs)``.
    """
    @wraps(f)
    def with_cursor(self, *args, **kwargs):
        connection = self._connection
        with connection:
            cursor = connection.cursor()
            return f(self, cursor, *args, **kwargs)
    return with_cursor


def memory_connect(path, *a, **kw):
    """
    Drop-in replacement for the default connect function that ignores
    ``path`` and always opens a new in-memory SQLite3 database.

    Any further positional or keyword arguments are forwarded unchanged.
    """
    # ``path`` is accepted only so the call signature matches; it is unused.
    target = ":memory:"
    return _connect(target, *a, **kw)


@attr.s(frozen=True)
class PaymentReferenceStore(object):
    """
    This class implements persistence for payment references.

    :ivar FilePath database_path: The location of the SQLite3 database file
        backing this store.

    :ivar _connection: The SQLite3 connection used for every query
        (``from_node_config`` obtains it from ``open_and_initialize``).
    """
    database_path = attr.ib(type=FilePath)
    _connection = attr.ib()

    @classmethod
    def from_node_config(cls, node_config, connect=None):
        """
        Create a store backed by the database in the node's private directory.

        :param allmydata.node._Config node_config: The Tahoe-LAFS node
            configuration object for the node that owns the persisted payment
            references.

        :param connect: Passed through to ``open_and_initialize``.

        :return PaymentReferenceStore: A store using a schema version 1
            database named ``CONFIG_DB_NAME``.
        """
        db_path = FilePath(node_config.get_private_path(CONFIG_DB_NAME))
        conn = open_and_initialize(
            db_path,
            required_schema_version=1,
            connect=connect,
        )
        return cls(
            db_path,
            conn,
        )

    @with_cursor
    def get(self, cursor, prn):
        """
        Look up one stored payment reference.

        :param cursor: Supplied by ``with_cursor``; callers pass only ``prn``.
        :param prn: The payment reference number to look up.

        :raise KeyError: If no payment reference with that number is stored.

        :return PaymentReference: The stored payment reference.
        """
        cursor.execute(
            """
            SELECT
                ([number])
            FROM
                [payment-references]
            WHERE
                [number] = ?
            """,
            (prn,),
        )
        refs = cursor.fetchall()
        if not refs:
            raise KeyError(prn)
        # [number] is the table's primary key, so there is at most one row.
        return PaymentReference(refs[0][0])

    @with_cursor
    def add(self, cursor, prn):
        """
        Store a payment reference number.

        Adding a number that is already stored is a no-op (``INSERT OR
        IGNORE``).

        :param cursor: Supplied by ``with_cursor``; callers pass only ``prn``.
        :param prn: The payment reference number to store.
        """
        cursor.execute(
            """
            INSERT OR IGNORE INTO [payment-references] VALUES (?)
            """,
            (prn,)
        )

    @with_cursor
    def list(self, cursor):
        """
        Get every stored payment reference.

        :param cursor: Supplied by ``with_cursor``; callers pass nothing.

        :return: A list of ``PaymentReference``, one per stored number, in
            whatever order SQLite returns them (no ``ORDER BY`` is applied).
        """
        cursor.execute(
            """
            SELECT ([number]) FROM [payment-references]
            """,
        )
        refs = cursor.fetchall()
        return [
            PaymentReference(number)
            for (number,)
            in refs
        ]


@attr.s
class PaymentReference(object):
    """
    A payment reference, identified by its ``number``.

    Instances serialize to a versioned JSON object with ``to_json`` and are
    rebuilt from one with ``from_json``.
    """
    number = attr.ib()

    @classmethod
    def from_json(cls, json):
        """
        Deserialize a payment reference from a JSON string.

        The object's ``version`` field selects the ``from_json_v<version>``
        classmethod that interprets the remaining fields.  A missing
        ``version`` raises ``KeyError``; an unsupported one raises
        ``AttributeError`` from the method lookup.
        """
        values = loads(json)
        version = values.pop(u"version")
        loader = getattr(cls, "from_json_v{}".format(version))
        return loader(values)

    @classmethod
    def from_json_v1(cls, values):
        """
        Build an instance from the version 1 fields (``version`` already
        removed), passed as keyword arguments.
        """
        return cls(**values)

    def to_json(self):
        """
        Serialize this payment reference to a JSON string.
        """
        marshalled = self.marshal()
        return dumps(marshalled)

    def marshal(self):
        """
        Return the dict representation handed to ``dumps`` (currently the
        version 1 form).
        """
        return self.to_json_v1()

    def to_json_v1(self):
        """
        Return the version 1 representation: every attrs attribute plus a
        ``version`` key set to 1.
        """
        fields = attr.asdict(self)
        fields[u"version"] = 1
        return fields