# Skip to content
# Snippets Groups Projects
# test_recover.py 13.4 KiB
# Newer Older
"""
Tests for ``_zkapauthorizer.recover``, the replication recovery system.
"""

from io import StringIO
from sqlite3 import Connection, connect
from typing import Callable, Dict
from allmydata.node import config_from_string
from attrs import Factory, define, field
from fixtures import TempDir
from hypothesis import assume, given, note, settings
from hypothesis.stateful import (
    RuleBasedStateMachine,
    invariant,
    precondition,
    rule,
    run_state_machine_as_test,
)
from hypothesis.strategies import data, lists, randoms, sampled_from, text
from testtools import TestCase
from testtools.matchers import AfterPreprocessing, Always, Equals, IsInstance
from testtools.twistedsupport import failed, succeeded
from twisted.internet.defer import Deferred
from twisted.python.filepath import FilePath
from zope.interface import implementer

from .._sql import Table, create_table
from ..recover import (
    AlreadyRecovering,
    BrokenRecoverer,
    IRecoverer,
    ISynchronousRecoverer,
    MemorySnapshotRecoverer,
    NullRecoverer,
    RecoveryStages,
    StatefulRecoverer,
    SynchronousStorageSnapshotRecoverer,
    TahoeLAFSRecoverer,
)
from .strategies import deletes, inserts, sql_identifiers, tables, updates


def snapshot(connection: Connection):
    """
    Generate the SQL statements which reproduce the contents of a database.

    :param connection: The SQLite3 connection to dump.

    :return: A generator of the statements produced by
        ``connection.iterdump()``, each with a trailing newline appended.
    """
    yield from (line + "\n" for line in connection.iterdump())


def equals_db(reference):
    """
    Build a matcher that compares a database against ``reference``.

    :param reference: A database connection whose ``iterdump()`` output is
        the expected value.  It is dumped once, when this function is called.

    :return: A matcher which dumps the matched connection with
        ``iterdump()`` and requires the resulting list of statements to equal
        the dump of ``reference``.
    """
    expected_dump = list(reference.iterdump())
    matches_expected = Equals(expected_dump)
    return AfterPreprocessing(
        lambda actual: list(actual.iterdump()),
        matches_expected,
    )


class SnapshotMachine(RuleBasedStateMachine):
    """
    Transition rules for a state machine corresponding to the state of a
    SQLite3 database.  Transitions are schema changes, row inserts, row
    updates, row deletions, etc.

    :ivar case: The test case used to make assertions.

    :ivar connection: The in-memory SQLite3 database the rules modify.

    :ivar make_recoverer: A callable which is given the list of snapshot
        statements and returns the recoverer to exercise.

    :ivar tables: A mapping from the name of each table created so far to
        its definition.
    """

    def __init__(self, case, make_recoverer):
        super().__init__()
        self.case = case
        self.connection = connect(":memory:")
        self.make_recoverer = make_recoverer

        self.tables: Dict[str, Table] = {}

    @invariant()
    def snapshot_equals_database(self):
        """
        At all points a snapshot of the database can be used to construct a new
        database with the same contents.
        """
        statements = list(snapshot(self.connection))
        new = connect(":memory:")
        try:
            recoverer = self.make_recoverer(statements)

            # Capture the most recent state the recoverer reports so the
            # final stage can be checked below.
            state = None

            def set_state(new_state):
                nonlocal state
                state = new_state

            recoverer.recover(set_state, None, new)
            self.case.assertThat(
                state.stage,
                Equals(RecoveryStages.succeeded),
            )

            self.case.assertThat(
                new,
                equals_db(reference=self.connection),
                "source (reference) database iterdump does not equal "
                "sink (actual) database iterdump",
            )
        finally:
            # This invariant runs after every step; don't leave a scratch
            # database open each time.
            new.close()

    @rule(
        name=sql_identifiers(),
        table=tables(),
    )
    def create_table(self, name, table):
        """
        Create a new table in the database.
        """
        assume(name not in self.tables)
        self.tables[name] = table
        statement = create_table(name, table)
        note(f"executing {statement!r}")
        self.connection.execute(statement)

    @precondition(lambda self: len(self.tables) > 0)
    @rule(
        change_types=lists(sampled_from([inserts, deletes, updates]), min_size=1),
        random=randoms(),
        data=data(),
    )
    def modify_rows(self, change_types, random, data):
        """
        Change some rows in some tables.
        """

        for change_type in change_types:
            # Choose a table to impact
            table_name = random.choice(sorted(self.tables))
            # Construct the change
            changes = data.draw(lists(change_type(table_name, self.tables[table_name])))
            # Execute the changes
            for change in changes:
                note(f"executing {change.statement()!r} {change.arguments()!r}")
                self.connection.execute(change.statement(), change.arguments())


def run_snapshot_machine(
    case: TestCase, make_recoverer: Callable[..., ISynchronousRecoverer]
) -> None:
    """
    Run ``SnapshotMachine`` as a unit test.

    :param case: The test case the state machine uses for its assertions.

    :param make_recoverer: A callable passed through to ``SnapshotMachine``,
        which calls it with the list of snapshot statements to get the
        recoverer to exercise.
    """
    # Many shallow runs are probably more useful than fewer deep runs.  That
    # is, exercise breadth in preference to depth.
    #
    # Also try to play along with any profile that has been loaded.
    defaults = settings.default
    max_examples = defaults.max_examples * 10
    stateful_step_count = max(1, defaults.stateful_step_count // 10)

    run_state_machine_as_test(
        lambda: SnapshotMachine(case, make_recoverer),
        settings=settings(
            max_examples=max_examples,
            stateful_step_count=stateful_step_count,
        ),
    )


class MemorySnapshotRecovererTests(TestCase):
    """
    Tests for ``MemorySnapshotRecoverer``.
    """

    def test_snapshots(self):
        """
        Test the snapshot/recovery system using a ``MemorySnapshotRecoverer``.
        """
        run_snapshot_machine(self, MemorySnapshotRecoverer)


@implementer(IRecoverer)
@define
class _AsyncRecovererWrapper:
    """
    Present an ``ISynchronousRecoverer`` through the coroutine-based
    ``IRecoverer`` interface.

    :ivar _wrapped: The synchronous recoverer which does the actual work.
    """

    _wrapped: ISynchronousRecoverer

    async def recover(self, set_state, cap, cursor):
        """
        Call the wrapped recoverer's ``recover`` with the same arguments.

        The wrapped call runs synchronously inside this coroutine and its
        result is discarded.
        """
        delegate = self._wrapped
        delegate.recover(set_state, cap, cursor)


class RecovererTestsMixin:
    """
    A mixin defining interface tests that any ``IRecoverer`` implementation
    should pass.

    Classes using this mixin must provide ``upload`` and ``make_recoverer``.
    """

    def upload(self, data: bytes) -> str:
        """
        Store ``data`` somewhere the recoverer under test can find it.

        :return: The capability string to pass to ``IRecoverer.recover``.
        """
        raise NotImplementedError()

    def make_recoverer(self) -> IRecoverer:
        """
        Create the ``IRecoverer`` under test.

        This takes no arguments: ``test_recover`` calls it before the snapshot
        is uploaded and passes the capability to ``recover`` instead.
        """
        raise NotImplementedError()

    def test_recover(self):
        """
        ``IRecoverer.recover`` loads statements from its path into the cursor
        given.
        """
        statements = [
            # Some DDL
            "CREATE TABLE [foo] ( [a] INT );",
            # Some DML
            "INSERT INTO [foo] ([a]) VALUES (1);",
        ]

        # Construct a database we can use to create a snapshot.
        with connect(":memory:") as conn:
            cursor = conn.cursor()
            for statement in statements:
                cursor.execute(statement)

            snapshot_statements = list(snapshot(conn))

        # Create the object under test - the recoverer which can recover from
        # this snapshot.
        recoverer = self.make_recoverer()

        # Put the snapshot somewhere the recoverer will be able to find it.
        cap = self.upload(("\n".join(snapshot_statements) + "\n").encode("ascii"))

        # Create a database we can recover the snapshot into.
        with connect(":memory:") as conn:
            cursor = conn.cursor()

            self.assertThat(
                # Do the recovery.
                Deferred.fromCoroutine(
                    recoverer.recover(lambda state: None, cap, cursor),
                ),
                succeeded(Always()),
            )
            # A snapshot of the recovered database should be the same as a
            # snapshot of the original.
            self.assertThat(
                list(snapshot(conn)),
                Equals(snapshot_statements),
            )


@define
class MemoryGrid:
    """
    An extremely simplified in-memory model of a Tahoe-LAFS storage grid.
    This object allows data to be "uploaded" to it and produces capability
    strings which can then be used to "download" the data from it later on.

    :ivar _counter: An internal counter used to support the creation of
        capability strings.

    :ivar _objects: Storage for all data which has been "uploaded", as a
        mapping from the capability strings to the values.
    """

    _counter: int = 0
    _objects: Dict[str, bytes] = field(default=Factory(dict))

    def upload(self, data: bytes) -> str:
        """
        Store ``data`` and return a new capability string for it.

        The capability is the decimal string of the current counter value,
        and the counter is incremented afterwards.
        """
        cap = str(self._counter)
        self._objects[cap] = data
        self._counter += 1
        return cap

    def download(self, cap: str) -> bytes:
        """
        Return the data previously uploaded under ``cap``.

        :raise KeyError: If nothing has been uploaded under ``cap``.
        """
        return self._objects[cap]


class SynchronousStorageSnapshotRecovererTests(TestCase, RecovererTestsMixin):
    """
    Tests for ``SynchronousStorageSnapshotRecoverer``.
    """

    def setUp(self):
        """
        Create a fresh ``MemoryGrid`` and use its ``upload`` as the mixin's
        ``upload`` hook.
        """
        super().setUp()
        self.grid = MemoryGrid()
        self.upload = self.grid.upload

    def make_recoverer(self) -> IRecoverer:
        """
        Create a ``SynchronousStorageSnapshotRecoverer`` that resolves
        capabilities from the ``MemoryGrid``, wrapped so that it provides
        ``IRecoverer``.
        """

        def open_from_grid(cap):
            # Present the uploaded bytes as a text stream.
            return StringIO(self.grid.download(cap).decode("ascii"))

        # SynchronousStorageSnapshotRecoverer is an ISynchronousRecoverer so
        # wrap it up to make it look like an IRecoverer.  This doesn't
        # actually make it asynchronous, of course, but makes the interfaces
        # line up.
        return _AsyncRecovererWrapper(
            SynchronousStorageSnapshotRecoverer(
                open_from_grid,
            )
        )


class NoSuchCapability(Exception):
    """
    Raised by the test-only downloader when asked for a capability that was
    never uploaded.
    """


class TahoeLAFSRecovererTests(TestCase, RecovererTestsMixin):
    """
    Tests for ``TahoeLAFSRecoverer``.
    """

    def setUp(self):
        """
        Create a fresh ``MemoryGrid`` and a temporary node directory that has
        a ``private`` child directory.
        """
        super().setUp()
        self.grid = MemoryGrid()
        self.upload = self.grid.upload
        self.node_dir = FilePath(self.useFixture(TempDir()).join("node"))
        self.node_dir.child("private").makedirs()

    def make_recoverer(self) -> TahoeLAFSRecoverer:
        """
        Create a ``TahoeLAFSRecoverer`` which downloads from the
        ``MemoryGrid`` instead of a real Tahoe-LAFS node.
        """

        async def download(client, outpath, api_root, cap):
            try:
                obj = self.grid.download(cap)
            except KeyError as e:
                raise NoSuchCapability() from e
            # Deliver the downloaded data where the recoverer asked for it.
            # NOTE(review): assumes ``outpath`` is a twisted ``FilePath`` --
            # confirm against ``TahoeLAFSRecoverer``.
            outpath.setContent(obj)

        # treq is used by the real download function but we're supplying our
        # own that doesn't need it, so value here should be irrelevant.  maybe
        # the http client should be implied by the download API rather than a
        # parameter it accepts?
        treq = object()
        node_config = config_from_string(self.node_dir.path, "", "")
        return TahoeLAFSRecoverer(treq, node_config, download)

    def test_recover_failed(self):
        """
        If the snapshot data cannot be found then ``IRecoverer.recover`` raises
        the underlying exception.
        """
        recoverer = self.make_recoverer()

        # Just invent a capability.  There will be no associated data and so
        # recovery will fail.
        cap = "abcdef"

        states = []
        record_state = states.append

        with connect(":memory:") as conn:
            cursor = conn.cursor()
            # We expect the recoverer to fail with the exception raised by the
            # downloader for now.  Later we probably want it to inspect the
            # exception and sometimes take a different action.  It will
            # probably never recognize our test-only NoSuchCapability
            # exception though.
            self.assertThat(
                Deferred.fromCoroutine(
                    recoverer.recover(record_state, cap, cursor),
                ),
                failed(
                    AfterPreprocessing(
                        lambda f: f.value,
                        IsInstance(NoSuchCapability),
                    ),
                ),
            )


class StatefulRecovererTests(TestCase):
    """
    Tests for ``StatefulRecoverer``.
    """

    @given(text())
    def test_succeeded_after_recover(self, cap):
        """
        When the wrapped recoverer finishes without raising, the
        ``StatefulRecoverer`` ends up in the succeeded stage.
        """
        stateful = StatefulRecoverer(NullRecoverer())
        with connect(":memory:") as db:
            recovering = Deferred.fromCoroutine(
                stateful.recover(cap, db.cursor()),
            )
            self.assertThat(recovering, succeeded(Always()))
            stage = stateful.state().stage
            self.assertThat(stage, Equals(RecoveryStages.succeeded))

    @given(text())
    def test_failed_after_recover(self, cap):
        """
        When the wrapped recoverer raises, the ``StatefulRecoverer`` ends up
        in the failed stage while its own ``recover`` still completes
        successfully.
        """
        stateful = StatefulRecoverer(BrokenRecoverer())
        with connect(":memory:") as db:
            recovering = Deferred.fromCoroutine(
                stateful.recover(cap, db.cursor()),
            )
            self.assertThat(recovering, succeeded(Always()))
            stage = stateful.state().stage
            self.assertThat(stage, Equals(RecoveryStages.failed))

    @given(text())
    def test_cannot_recover_twice(self, cap):
        """
        Calling ``StatefulRecoverer.recover`` a second time fails with
        ``AlreadyRecovering``.
        """
        stateful = StatefulRecoverer(NullRecoverer())
        fails_already_recovering = failed(
            AfterPreprocessing(
                lambda f: f.value,
                IsInstance(AlreadyRecovering),
            ),
        )
        with connect(":memory:") as db:
            # Both attempts share one cursor.
            shared_cursor = db.cursor()
            initial = Deferred.fromCoroutine(stateful.recover(cap, shared_cursor))
            self.assertThat(initial, succeeded(Always()))
            repeat = Deferred.fromCoroutine(stateful.recover(cap, shared_cursor))
            self.assertThat(repeat, fails_already_recovering)