diff options
Diffstat (limited to 'chromium/components/policy/test_support/policy_storage.cc')
-rw-r--r-- | chromium/components/policy/test_support/policy_storage.cc | 128 |
1 files changed, 128 insertions, 0 deletions
diff --git a/chromium/components/policy/test_support/policy_storage.cc b/chromium/components/policy/test_support/policy_storage.cc new file mode 100644 index 00000000000..bb6643c18e9 --- /dev/null +++ b/chromium/components/policy/test_support/policy_storage.cc @@ -0,0 +1,128 @@ +// Copyright 2021 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "components/policy/test_support/policy_storage.h" +#include "base/big_endian.h" +#include "base/strings/strcat.h" +#include "base/strings/string_util.h" +#include "crypto/sha2.h" + +namespace policy { + +namespace { + +const char kPolicyKeySeparator[] = "/"; + +std::string GetPolicyKey(const std::string& policy_type, + const std::string& entity_id) { + if (entity_id.empty()) + return policy_type; + return base::StrCat({policy_type, kPolicyKeySeparator, entity_id}); +} + +} // namespace + +PolicyStorage::PolicyStorage() + : signature_provider_(std::make_unique<SignatureProvider>()) {} + +PolicyStorage::PolicyStorage(PolicyStorage&& policy_storage) = default; + +PolicyStorage& PolicyStorage::operator=(PolicyStorage&& policy_storage) = + default; + +PolicyStorage::~PolicyStorage() = default; + +std::string PolicyStorage::GetPolicyPayload( + const std::string& policy_type, + const std::string& entity_id) const { + auto it = policy_payloads_.find(GetPolicyKey(policy_type, entity_id)); + return it == policy_payloads_.end() ? std::string() : it->second; +} + +std::vector<std::string> PolicyStorage::GetEntityIdsForType( + const std::string& policy_type) { + std::string prefix = base::StrCat({policy_type, kPolicyKeySeparator}); + std::vector<std::string> ids; + const size_t prefix_length = prefix.length(); + for (const auto& [policy_key, payload] : policy_payloads_) { + if (base::StartsWith(policy_key, prefix)) + ids.push_back(policy_key.substr(prefix_length)); + } + return ids; +} + +void PolicyStorage::SetPolicyPayload(const std::string& policy_type, + const std::string& policy_payload) { + SetPolicyPayload(policy_type, std::string(), policy_payload); +} + +void PolicyStorage::SetPolicyPayload(const std::string& policy_type, + const std::string& entity_id, + const std::string& policy_payload) { + policy_payloads_[GetPolicyKey(policy_type, entity_id)] = policy_payload; +} + +std::string PolicyStorage::GetExternalPolicyPayload( + const std::string& policy_type, + const std::string& entity_id) { + std::string policy_key = GetPolicyKey(policy_type, entity_id); + return external_policy_payloads_.contains(policy_key) + ? external_policy_payloads_.at(policy_key) + : std::string(); +} + +void PolicyStorage::SetExternalPolicyPayload( + const std::string& policy_type, + const std::string& entity_id, + const std::string& policy_payload) { + external_policy_payloads_[GetPolicyKey(policy_type, entity_id)] = + policy_payload; +} + +void PolicyStorage::SetPsmEntry(const std::string& brand_serial_id, + const PolicyStorage::PsmEntry& psm_entry) { + psm_entries_[brand_serial_id] = psm_entry; +} + +const PolicyStorage::PsmEntry* PolicyStorage::GetPsmEntry( + const std::string& brand_serial_id) const { + auto it = psm_entries_.find(brand_serial_id); + if (it == psm_entries_.end()) + return nullptr; + return &it->second; +} + +void PolicyStorage::SetInitialEnrollmentState( + const std::string& brand_serial_id, + const PolicyStorage::InitialEnrollmentState& initial_enrollment_state) { + initial_enrollment_states_[brand_serial_id] = initial_enrollment_state; +} + +const PolicyStorage::InitialEnrollmentState* +PolicyStorage::GetInitialEnrollmentState( + const std::string& brand_serial_id) const { + auto it = initial_enrollment_states_.find(brand_serial_id); + if (it == initial_enrollment_states_.end()) + return nullptr; + return &it->second; +} + +std::vector<std::string> PolicyStorage::GetMatchingSerialHashes( + uint64_t modulus, + uint64_t remainder) const { + std::vector<std::string> hashes; + for (const auto& [serial, enrollment_state] : initial_enrollment_states_) { + uint64_t hash = 0UL; + uint8_t hash_bytes[sizeof(hash)]; + crypto::SHA256HashString(serial, hash_bytes, sizeof(hash)); + base::ReadBigEndian(hash_bytes, &hash); + if (hash % modulus == remainder) { + hashes.emplace_back(reinterpret_cast<const char*>(hash_bytes), + sizeof(hash)); + } + } + return hashes; +} + +} // namespace policy |