diff options
Diffstat (limited to 'chromium/third_party/webrtc/media/sctp/sctp_transport.cc')
-rw-r--r-- | chromium/third_party/webrtc/media/sctp/sctp_transport.cc | 158 |
1 files changed, 94 insertions, 64 deletions
diff --git a/chromium/third_party/webrtc/media/sctp/sctp_transport.cc b/chromium/third_party/webrtc/media/sctp/sctp_transport.cc index 13e7db49ce5..caf52dbbade 100644 --- a/chromium/third_party/webrtc/media/sctp/sctp_transport.cc +++ b/chromium/third_party/webrtc/media/sctp/sctp_transport.cc @@ -27,6 +27,7 @@ constexpr int kSctpSuccessReturn = 1; #include <stdio.h> #include <usrsctp.h> +#include <functional> #include <memory> #include <unordered_map> @@ -79,58 +80,8 @@ enum { PPID_TEXT_LAST = 51 }; -// Maps SCTP transport ID to SctpTransport object, necessary in send threshold -// callback and outgoing packet callback. -// TODO(crbug.com/1076703): Remove once the underlying problem is fixed or -// workaround is provided in usrsctp. -class SctpTransportMap { - public: - SctpTransportMap() = default; - - // Assigns a new unused ID to the following transport. - uintptr_t Register(cricket::SctpTransport* transport) { - webrtc::MutexLock lock(&lock_); - // usrsctp_connect fails with a value of 0... - if (next_id_ == 0) { - ++next_id_; - } - // In case we've wrapped around and need to find an empty spot from a - // removed transport. Assumes we'll never be full. - while (map_.find(next_id_) != map_.end()) { - ++next_id_; - if (next_id_ == 0) { - ++next_id_; - } - }; - map_[next_id_] = transport; - return next_id_++; - } - - // Returns true if found. - bool Deregister(uintptr_t id) { - webrtc::MutexLock lock(&lock_); - return map_.erase(id) > 0; - } - - cricket::SctpTransport* Retrieve(uintptr_t id) const { - webrtc::MutexLock lock(&lock_); - auto it = map_.find(id); - if (it == map_.end()) { - return nullptr; - } - return it->second; - } - - private: - mutable webrtc::Mutex lock_; - - uintptr_t next_id_ RTC_GUARDED_BY(lock_) = 0; - std::unordered_map<uintptr_t, cricket::SctpTransport*> map_ - RTC_GUARDED_BY(lock_); -}; - // Should only be modified by UsrSctpWrapper. -ABSL_CONST_INIT SctpTransportMap* g_transport_map_ = nullptr; +ABSL_CONST_INIT cricket::SctpTransportMap* g_transport_map_ = nullptr; // Helper for logging SCTP messages. #if defined(__GNUC__) @@ -256,6 +207,83 @@ sctp_sendv_spa CreateSctpSendParams(const cricket::SendDataParams& params) { namespace cricket { +// Maps SCTP transport ID to SctpTransport object, necessary in send threshold +// callback and outgoing packet callback. It also provides a facility to +// safely post a task to an SctpTransport's network thread from another thread. +class SctpTransportMap { + public: + SctpTransportMap() = default; + + // Assigns a new unused ID to the following transport. + uintptr_t Register(cricket::SctpTransport* transport) { + webrtc::MutexLock lock(&lock_); + // usrsctp_connect fails with a value of 0... + if (next_id_ == 0) { + ++next_id_; + } + // In case we've wrapped around and need to find an empty spot from a + // removed transport. Assumes we'll never be full. + while (map_.find(next_id_) != map_.end()) { + ++next_id_; + if (next_id_ == 0) { + ++next_id_; + } + }; + map_[next_id_] = transport; + return next_id_++; + } + + // Returns true if found. + bool Deregister(uintptr_t id) { + webrtc::MutexLock lock(&lock_); + return map_.erase(id) > 0; + } + + // Must be called on the transport's network thread to protect against + // simultaneous deletion/deregistration of the transport; if that's not + // guaranteed, use ExecuteWithLock. + SctpTransport* Retrieve(uintptr_t id) const { + webrtc::MutexLock lock(&lock_); + SctpTransport* transport = RetrieveWhileHoldingLock(id); + if (transport) { + RTC_DCHECK_RUN_ON(transport->network_thread()); + } + return transport; + } + + // Posts |action| to the network thread of the transport identified by |id| + // and returns true if found, all while holding a lock to protect against the + // transport being simultaneously deleted/deregistered, or returns false if + // not found. + bool PostToTransportThread(uintptr_t id, + std::function<void(SctpTransport*)> action) const { + webrtc::MutexLock lock(&lock_); + SctpTransport* transport = RetrieveWhileHoldingLock(id); + if (!transport) { + return false; + } + transport->invoker_.AsyncInvoke<void>( + RTC_FROM_HERE, transport->network_thread_, [transport, action]() { + action(transport); }); + return true; + } + + private: + SctpTransport* RetrieveWhileHoldingLock(uintptr_t id) const + RTC_EXCLUSIVE_LOCKS_REQUIRED(lock_) { + auto it = map_.find(id); + if (it == map_.end()) { + return nullptr; + } + return it->second; + } + + mutable webrtc::Mutex lock_; + + uintptr_t next_id_ RTC_GUARDED_BY(lock_) = 0; + std::unordered_map<uintptr_t, SctpTransport*> map_ RTC_GUARDED_BY(lock_); +}; + // Handles global init/deinit, and mapping from usrsctp callbacks to // SctpTransport calls. class SctpTransport::UsrSctpWrapper { @@ -357,14 +385,6 @@ class SctpTransport::UsrSctpWrapper { << "OnSctpOutboundPacket called after usrsctp uninitialized?"; return EINVAL; } - SctpTransport* transport = - g_transport_map_->Retrieve(reinterpret_cast<uintptr_t>(addr)); - if (!transport) { - RTC_LOG(LS_ERROR) - << "OnSctpOutboundPacket: Failed to get transport for socket ID " - << addr; - return EINVAL; - } RTC_LOG(LS_VERBOSE) << "global OnSctpOutboundPacket():" "addr: " << addr << "; length: " << length @@ -372,13 +392,23 @@ class SctpTransport::UsrSctpWrapper { << "; set_df: " << rtc::ToHex(set_df); VerboseLogPacket(data, length, SCTP_DUMP_OUTBOUND); + // Note: We have to copy the data; the caller will delete it. rtc::CopyOnWriteBuffer buf(reinterpret_cast<uint8_t*>(data), length); - // TODO(deadbeef): Why do we need an AsyncInvoke here? We're already on the - // right thread and don't need to unwind the stack. - transport->invoker_.AsyncInvoke<void>( - RTC_FROM_HERE, transport->network_thread_, - rtc::Bind(&SctpTransport::OnPacketFromSctpToNetwork, transport, buf)); + + // PostsToTransportThread protects against the transport being + // simultaneously deregistered/deleted, since this callback may come from + // the SCTP timer thread and thus race with the network thread. + bool found = g_transport_map_->PostToTransportThread( + reinterpret_cast<uintptr_t>(addr), [buf](SctpTransport* transport) { + transport->OnPacketFromSctpToNetwork(buf); + }); + if (!found) { + RTC_LOG(LS_ERROR) + << "OnSctpOutboundPacket: Failed to get transport for socket ID " + << addr; + return EINVAL; + } return 0; } |