// Copyright (C) 2016 The Qt Company Ltd. // SPDX-License-Identifier: LicenseRef-Qt-Commercial OR GPL-3.0-only WITH Qt-GPL-exception-1.0 #pragma once #include "sqlite3_fwd.h" #include "sqliteglobal.h" #include "sqliteblob.h" #include "sqliteexception.h" #include "sqliteids.h" #include "sqlitetransaction.h" #include "sqlitevalue.h" #include #include #include #include #include #include #include #include #include using std::int64_t; namespace Sqlite { class Database; class DatabaseBackend; enum class Type : char { Invalid, Integer, Float, Text, Blob, Null }; template constexpr static std::underlying_type_t to_underlying(Enumeration enumeration) noexcept { static_assert(std::is_enum_v, "to_underlying expect an enumeration"); return static_cast>(enumeration); } class SQLITE_EXPORT BaseStatement { public: using Database = ::Sqlite::Database; explicit BaseStatement(Utils::SmallStringView sqlStatement, Database &database); BaseStatement(const BaseStatement &) = delete; BaseStatement &operator=(const BaseStatement &) = delete; static void deleteCompiledStatement(sqlite3_stmt *m_compiledStatement); bool next() const; void step() const; void reset() const noexcept; Type fetchType(int column) const; int fetchIntValue(int column) const; long fetchLongValue(int column) const; long long fetchLongLongValue(int column) const; double fetchDoubleValue(int column) const; Utils::SmallStringView fetchSmallStringViewValue(int column) const; ValueView fetchValueView(int column) const; BlobView fetchBlobValue(int column) const; template Type fetchValue(int column) const; void bindNull(int index); void bind(int index, NullValue); void bind(int index, int value); void bind(int index, long long value); void bind(int index, double value); void bind(int index, void *pointer); void bind(int index, Utils::span values); void bind(int index, Utils::span values); void bind(int index, Utils::span values); void bind(int index, Utils::span values); void bind(int index, Utils::SmallStringView value); void bind(int index, const Value &value); void bind(int index, ValueView value); void bind(int index, BlobView blobView); template> void bind(int index, Type id) { if (id) bind(index, id.internalId()); else bindNull(index); } template, bool> = true> void bind(int index, Enumeration enumeration) { bind(index, to_underlying(enumeration)); } void bind(int index, uint value) { bind(index, static_cast(value)); } void bind(int index, long value) { bind(index, static_cast(value)); } void prepare(Utils::SmallStringView sqlStatement); void waitForUnlockNotify() const; sqlite3 *sqliteDatabaseHandle() const; [[noreturn]] void checkForStepError(int resultCode) const; [[noreturn]] void checkForPrepareError(int resultCode) const; [[noreturn]] void checkForBindingError(int resultCode) const; void setIfIsReadyToFetchValues(int resultCode) const; void checkBindingName(int index) const; void checkBindingParameterCount(int bindingParameterCount) const; void checkColumnCount(int columnCount) const; bool isReadOnlyStatement() const; [[noreturn]] void throwStatementIsBusy(const char *whatHasHappened) const; [[noreturn]] void throwStatementHasError(const char *whatHasHappened) const; [[noreturn]] void throwStatementIsMisused(const char *whatHasHappened) const; [[noreturn]] void throwInputOutputError(const char *whatHasHappened) const; [[noreturn]] void throwConstraintPreventsModification(const char *whatHasHappened) const; [[noreturn]] void throwNoValuesToFetch(const char *whatHasHappened) const; [[noreturn]] void throwInvalidColumnFetched(const char *whatHasHappened) const; [[noreturn]] void throwBindingIndexIsOutOfRange(const char *whatHasHappened) const; [[noreturn]] void throwWrongBingingName(const char *whatHasHappened) const; [[noreturn]] void throwUnknowError(const char *whatHasHappened) const; [[noreturn]] void throwBingingTooBig(const char *whatHasHappened) const; [[noreturn]] void throwTooBig(const char *whatHasHappened) const; [[noreturn]] void throwSchemaChangeError(const char *whatHasHappened) const; [[noreturn]] void throwCannotWriteToReadOnlyConnection(const char *whatHasHappened) const; [[noreturn]] void throwProtocolError(const char *whatHasHappened) const; [[noreturn]] void throwDatabaseExceedsMaximumFileSize(const char *whatHasHappened) const; [[noreturn]] void throwDataTypeMismatch(const char *whatHasHappened) const; [[noreturn]] void throwConnectionIsLocked(const char *whatHasHappened) const; [[noreturn]] void throwExecutionInterrupted(const char *whatHasHappened) const; [[noreturn]] void throwDatabaseIsCorrupt(const char *whatHasHappened) const; [[noreturn]] void throwCannotOpen(const char *whatHasHappened) const; QString columnName(int column) const; Database &database() const; protected: ~BaseStatement() = default; private: std::unique_ptr m_compiledStatement; Database &m_database; }; template <> SQLITE_EXPORT int BaseStatement::fetchValue(int column) const; template <> SQLITE_EXPORT long BaseStatement::fetchValue(int column) const; template <> SQLITE_EXPORT long long BaseStatement::fetchValue(int column) const; template <> SQLITE_EXPORT double BaseStatement::fetchValue(int column) const; extern template SQLITE_EXPORT Utils::SmallStringView BaseStatement::fetchValue(int column) const; extern template SQLITE_EXPORT Utils::SmallString BaseStatement::fetchValue(int column) const; extern template SQLITE_EXPORT Utils::PathString BaseStatement::fetchValue(int column) const; template class StatementImplementation : public BaseStatement { struct Resetter; public: using BaseStatement::BaseStatement; void execute() { Resetter resetter{this}; BaseStatement::next(); } template void bindValues(const ValueType &...values) { static_assert(BindParameterCount == sizeof...(values), "Wrong binding parameter count!"); int index = 0; (BaseStatement::bind(++index, values), ...); } template void write(const ValueType&... values) { Resetter resetter{this}; bindValues(values...); BaseStatement::next(); } template auto values(std::size_t reserveSize, const QueryTypes &...queryValues) { Resetter resetter{this}; std::vector resultValues; resultValues.reserve(std::max(reserveSize, m_maximumResultCount)); bindValues(queryValues...); while (BaseStatement::next()) emplaceBackValues(resultValues); setMaximumResultCount(resultValues.size()); return resultValues; } template auto value(const QueryTypes &...queryValues) { Resetter resetter{this}; ResultType resultValue{}; bindValues(queryValues...); if (BaseStatement::next()) resultValue = createValue(); return resultValue; } template auto optionalValue(const QueryTypes &...queryValues) { Resetter resetter{this}; std::optional resultValue; bindValues(queryValues...); if (BaseStatement::next()) resultValue = createOptionalValue>(); return resultValue; } template static auto toValue(Utils::SmallStringView sqlStatement, Database &database) { StatementImplementation statement(sqlStatement, database); statement.checkColumnCount(1); statement.next(); return statement.template fetchValue(0); } template void readCallback(Callable &&callable, const QueryTypes &...queryValues) { Resetter resetter{this}; bindValues(queryValues...); while (BaseStatement::next()) { auto control = callCallable(callable); if (control == CallbackControl::Abort) break; } } template void readTo(Container &container, const QueryTypes &...queryValues) { Resetter resetter{this}; bindValues(queryValues...); while (BaseStatement::next()) emplaceBackValues(container); } template auto range(const QueryTypes &...queryValues) { return SqliteResultRange{*this, queryValues...}; } template auto rangeWithTransaction(const QueryTypes &...queryValues) { return SqliteResultRangeWithTransaction{*this, queryValues...}; } template class BaseSqliteResultRange { public: class SqliteResultIteratator { public: using iterator_category = std::input_iterator_tag; using difference_type = int; using value_type = ResultType; using pointer = ResultType *; using reference = ResultType &; SqliteResultIteratator(StatementImplementation &statement) : m_statement{statement} , m_hasNext{m_statement.next()} {} SqliteResultIteratator(StatementImplementation &statement, bool hasNext) : m_statement{statement} , m_hasNext{hasNext} {} SqliteResultIteratator &operator++() { m_hasNext = m_statement.next(); return *this; } void operator++(int) { m_hasNext = m_statement.next(); } friend bool operator==(const SqliteResultIteratator &first, const SqliteResultIteratator &second) { return first.m_hasNext == second.m_hasNext; } friend bool operator!=(const SqliteResultIteratator &first, const SqliteResultIteratator &second) { return !(first == second); } value_type operator*() const { return m_statement.createValue(); } private: StatementImplementation &m_statement; bool m_hasNext = false; }; using value_type = ResultType; using iterator = SqliteResultIteratator; using const_iterator = iterator; template BaseSqliteResultRange(StatementImplementation &statement) : m_statement{statement} { } BaseSqliteResultRange(BaseSqliteResultRange &) = delete; BaseSqliteResultRange &operator=(BaseSqliteResultRange &) = delete; BaseSqliteResultRange(BaseSqliteResultRange &&other) : m_statement{std::move(other.resetter)} {} BaseSqliteResultRange &operator=(BaseSqliteResultRange &&) = delete; iterator begin() & { return iterator{m_statement}; } iterator end() & { return iterator{m_statement, false}; } const_iterator begin() const & { return iterator{m_statement}; } const_iterator end() const & { return iterator{m_statement, false}; } private: StatementImplementation &m_statement; }; template class SqliteResultRange : public BaseSqliteResultRange { public: template SqliteResultRange(StatementImplementation &statement, const QueryTypes &...queryValues) : BaseSqliteResultRange{statement} , resetter{&statement} { statement.bindValues(queryValues...); } private: Resetter resetter; }; template class SqliteResultRangeWithTransaction : public BaseSqliteResultRange { public: template SqliteResultRangeWithTransaction(StatementImplementation &statement, const QueryTypes &...queryValues) : BaseSqliteResultRange{statement} , m_transaction{statement.database()} , resetter{&statement} { statement.bindValues(queryValues...); } ~SqliteResultRangeWithTransaction() { resetter.reset(); if (!std::uncaught_exceptions()) { m_transaction.commit(); } } private: DeferredTransaction m_transaction; Resetter resetter; }; protected: ~StatementImplementation() = default; private: struct Resetter { Resetter(StatementImplementation *statement) : statement(statement) { if (statement && !statement->database().isLocked()) throw DatabaseIsNotLocked{"Database connection is not locked!"}; } Resetter(Resetter &) = delete; Resetter &operator=(Resetter &) = delete; Resetter(Resetter &&other) : statement{std::exchange(other.statement, nullptr)} {} void reset() { if (statement) statement->reset(); statement = nullptr; } ~Resetter() noexcept { reset(); } StatementImplementation *statement; }; struct ValueGetter { ValueGetter(StatementImplementation &statement, int column) : statement(statement) , column(column) {} explicit operator bool() const { return statement.fetchIntValue(column); } operator int() const { return statement.fetchIntValue(column); } operator long() const { return statement.fetchLongValue(column); } operator long long() const { return statement.fetchLongLongValue(column); } operator double() const { return statement.fetchDoubleValue(column); } operator Utils::SmallStringView() { return statement.fetchSmallStringViewValue(column); } operator BlobView() { return statement.fetchBlobValue(column); } operator ValueView() { return statement.fetchValueView(column); } template> constexpr operator ConversionType() { if (statement.fetchType(column) == Type::Integer) { if constexpr (std::is_same_v) return ConversionType::create(statement.fetchIntValue(column)); else return ConversionType::create(statement.fetchLongLongValue(column)); } return ConversionType{}; } template, bool> = true> constexpr operator Enumeration() { return static_cast(statement.fetchLongLongValue(column)); } StatementImplementation &statement; int column; }; template void emplaceBackValues(ContainerType &container, std::integer_sequence) { container.emplace_back(ValueGetter(*this, ColumnIndices)...); } template void emplaceBackValues(ContainerType &container) { emplaceBackValues(container, std::make_integer_sequence{}); } template ResultOptionalType createOptionalValue(std::integer_sequence) { return ResultOptionalType(std::in_place, ValueGetter(*this, ColumnIndices)...); } template ResultOptionalType createOptionalValue() { return createOptionalValue(std::make_integer_sequence{}); } template ResultType createValue(std::integer_sequence) { return ResultType(ValueGetter(*this, ColumnIndices)...); } template ResultType createValue() { return createValue(std::make_integer_sequence{}); } template CallbackControl callCallable(Callable &&callable, std::integer_sequence) { return std::invoke(callable, ValueGetter(*this, ColumnIndices)...); } template CallbackControl callCallable(Callable &&callable) { return callCallable(callable, std::make_integer_sequence{}); } void setMaximumResultCount(std::size_t count) { m_maximumResultCount = std::max(m_maximumResultCount, count); } public: std::size_t m_maximumResultCount = 0; }; } // namespace Sqlite