diff --git a/csrc/fusion.cpp b/csrc/fusion.cpp index b20f1f23355..e33a3a39418 100644 --- a/csrc/fusion.cpp +++ b/csrc/fusion.cpp @@ -35,6 +35,212 @@ namespace nvfuser { +// Lock-free implementation methods for Fusion operations that mutate +// IrContainer state. These are called while the caller holds unique_lock +// on ir_container()->mutex_, avoiding self-deadlock on nested calls +// (e.g., removeVal → removeExpr). +struct Fusion::ContainerMutator { + static void removeExpr(Fusion* self, Expr* expr) { + self->ir_container()->assertInContainerImpl(expr, "Cannot remove expr "); + + for (auto* out : expr->outputs()) { + if (out->isA()) { + self->invalidateTvsAndUses(); + } + out->setDefinition(nullptr); + } + + for (auto* inp : expr->inputs()) { + inp->removeUse(expr); + if (inp->isA()) { + self->invalidateTvsAndUses(); + } + } + + auto* c = self->ir_container(); + auto expr_in_deque = std::ranges::find_if( + c->exprs_up_, [expr](std::unique_ptr& expr_up) { + return expr_up.get() == expr; + }); + NVF_ERROR( + expr_in_deque != c->exprs_up_.end(), + "Wanted to remove an expression but its unique ptr is missing."); + c->per_fusion_exprs_[self].erase(expr); + c->exprs_.erase(expr); + c->exprs_up_.erase(expr_in_deque); + } + + static void removeVal(Fusion* self, Val* val) { + self->ir_container()->assertInContainerImpl(val, "Cannot remove val "); + + // Don't remove cached special vals — they are lazily created singletons + if (val == self->zero_val_ || val == self->one_val_ || + val == self->true_val_ || val == self->false_val_ || + val == self->magic_zero_val_) { + return; + } + + NVF_CHECK( + !val->isFusionInput(), + "Cannot remove val as it is an input of the fusion."); + NVF_CHECK( + !val->isFusionOutput(), + "Cannot remove val as it is an output of the fusion."); + + if (Expr* orig = val->definition()) { + removeExpr(self, orig); + } + + // We must scan all per-fusion owned exprs (not just live uses) to find + // all expressions that reference this val, including dead code. + auto* c = self->ir_container(); + std::vector exprs_to_remove; + const auto& owned_exprs = c->per_fusion_exprs_[self]; + for (Expr* e : owned_exprs) { + if (!c->inContainerImpl(e)) { + continue; + } + if (std::find(e->inputs().begin(), e->inputs().end(), val) != + e->inputs().end()) { + exprs_to_remove.push_back(e); + } + } + for (auto e : exprs_to_remove) { + removeExpr(self, e); + } + + auto val_in_deque = std::ranges::find_if( + c->vals_up_, + [val](std::unique_ptr& val_up) { return val_up.get() == val; }); + NVF_ERROR( + val_in_deque != c->vals_up_.end(), + "Wanted to remove a value but its unique ptr is missing."); + c->per_fusion_vals_[self].erase(val); + c->vals_.erase(val); + c->vals_up_.erase(val_in_deque); + + self->invalidateTvsAndUses(); + } + + static void registerVal(Fusion* self, Val* val) { + if (self->ir_container()->inContainerImpl(val)) { + return; + } + + if (val->fusion()) { + NVF_CHECK( + val->fusion() == self, val, " was not found in the active fusion."); + } + + auto* c = self->ir_container(); + c->vals_up_.emplace_back(val); + c->vals_.insert(val); + c->per_fusion_vals_[self].insert(val); + val->setName(IrContainerPasskey(), self->getValName(val->vtype())); + } + + static void registerExpr(Fusion* self, Expr* expr) { + if (self->ir_container()->inContainerImpl(expr)) { + return; + } + + if (expr->fusion()) { + NVF_CHECK( + expr->fusion() == self, expr, " was not found in the active fusion."); + } + + auto* c = self->ir_container(); + c->exprs_up_.emplace_back(expr); + c->exprs_.insert(expr); + c->per_fusion_exprs_[self].insert(expr); + expr->setName(IrContainerPasskey(), self->getExprName()); + + for (Val* input : expr->inputs()) { + c->assertInContainerImpl(input, "Input to expr is invalid, "); + if (input->isA()) { + self->invalidateTvsAndUses(); + } else { + input->addUse(expr); + } + } + + const bool is_ssa = + !self->isA() && !self->isA(); + + for (Val* output : expr->outputs()) { + c->assertInContainerImpl(output, "Output to expr is invalid, "); + if (output->definition() != nullptr && is_ssa) { + removeExpr(self, output->definition()); + } + if (is_ssa || output->definition() == nullptr) { + output->setDefinition(expr); + if (output->isA()) { + self->invalidateTvsAndUses(); + } + } + } + } + + static int64_t numValsExcludingShortcuts(const Fusion* self) noexcept { + auto* c = self->ir_container(); + // Use direct field access. Avoids re-entering valsOwnedBy() which acquires + // shared_lock. + const auto it = c->per_fusion_vals_.find(self); + int64_t count = it != c->per_fusion_vals_.end() + ? static_cast(it->second.size()) + : 0; + count -= (self->zero_val_ != nullptr) + (self->one_val_ != nullptr) + + (self->true_val_ != nullptr) + (self->false_val_ != nullptr) + + (self->magic_zero_val_ != nullptr); + return count; + } + + static void removeStatementsCreatedAfter( + Fusion* self, + int64_t num_exprs_before, + int64_t num_vals_before) { + auto* c = self->ir_container(); + + // Remove expressions before values because we need to change Val::uses_. + while (std::ssize(c->per_fusion_exprs_[self]) > num_exprs_before) { + // Pop from global deque back — statements created by this Fusion during + // the guard scope are at the tail (LIFO invariant). + Expr* e = c->exprs_up_.back().get(); + NVF_ERROR( + c->per_fusion_exprs_[self].count(e) > 0, + "removeStatementsCreatedAfter: tail expr belongs to another Fusion"); + for (Val* in : e->inputs()) { + in->removeUse(e); + } + c->per_fusion_exprs_[self].erase(e); + c->exprs_.erase(e); + c->exprs_up_.pop_back(); + } + + while (numValsExcludingShortcuts(self) > num_vals_before) { + Val* v = c->vals_up_.back().get(); + NVF_ERROR( + c->per_fusion_vals_[self].count(v) > 0, + "removeStatementsCreatedAfter: tail val belongs to another Fusion"); + // Null out shortcut caches if they point to vals about to be destroyed + if (v == self->zero_val_) { + self->zero_val_ = nullptr; + } else if (v == self->one_val_) { + self->one_val_ = nullptr; + } else if (v == self->true_val_) { + self->true_val_ = nullptr; + } else if (v == self->false_val_) { + self->false_val_ = nullptr; + } else if (v == self->magic_zero_val_) { + self->magic_zero_val_ = nullptr; + } + c->per_fusion_vals_[self].erase(v); + c->vals_.erase(v); + c->vals_up_.pop_back(); + } + } +}; + size_t Fusion::hash() const { size_t hash = 0; @@ -182,6 +388,7 @@ void Fusion::swap(Fusion& a, Fusion& b) { if (a.ir_container_.get() == b.ir_container_.get()) { // Same container: directly swap per-Fusion tracking entries auto* c = a.ir_container_.get(); + std::unique_lock lock(c->mutex_); std::swap(c->per_fusion_vals_[&a], c->per_fusion_vals_[&b]); std::swap(c->per_fusion_exprs_[&a], c->per_fusion_exprs_[&b]); } else { @@ -396,143 +603,25 @@ void Fusion::clear() noexcept { } void Fusion::removeExpr(Expr* expr) { - assertInContainer(expr, "Cannot remove expr "); - // If we hit this error too frequently, we could lighten the restrictions so - // that removing something that doesn't exist simply does nothing. For now, - // we're going with the strictest model which errors. - - for (auto* out : expr->outputs()) { - if (out->isA()) { - invalidateTvsAndUses(); - } - out->setDefinition(nullptr); - } - - // Remove uses in inputs - for (auto* inp : expr->inputs()) { - // Note that if inp is a TensorView, this may call invalidateTvsAndUses - inp->removeUse(expr); - if (inp->isA()) { - invalidateTvsAndUses(); - } - } - - auto* c = ir_container(); - auto expr_in_deque = std::ranges::find_if( - c->exprs_up_, - [expr](std::unique_ptr& expr_up) { return expr_up.get() == expr; }); - NVF_ERROR( - expr_in_deque != c->exprs_up_.end(), - "Wanted to remove an expression but its unique ptr is missing."); - c->per_fusion_exprs_[this].erase(expr); - c->exprs_.erase(expr); - c->exprs_up_.erase(expr_in_deque); + std::unique_lock lock(ir_container()->mutex_); + ContainerMutator::removeExpr(this, expr); } void Fusion::removeVal(Val* val) { - assertInContainer(val, "Cannot remove val "); - - // Don't remove cached special vals — they are lazily created singletons - if (val == zero_val_ || val == one_val_ || val == true_val_ || - val == false_val_ || val == magic_zero_val_) { - return; - } - - NVF_CHECK( - !val->isFusionInput(), - "Cannot remove val as it is an input of the fusion."); - NVF_CHECK( - !val->isFusionOutput(), - "Cannot remove val as it is an output of the fusion."); - - if (Expr* orig = val->definition()) { - removeExpr(orig); - } - - // We previously first looped over val->uses() and removed them all from the - // Fusion. This seems correct at first glance, but it is incomplete since - // `val->uses()` actually only gives all live uses. When there is dead code in - // the Fusion that includes some uses of a val that is to be removed, we can - // wind up with an expression that holds an invalid pointer to the removed - // value in its inputs(). In https://github.com/NVIDIA/Fuser/issues/1270 this - // caused a segfault when the fusion was cloned since that will clone not only - // live objects but also these dangerous dangling dead ones. - // - // IMPORTANT: We must use unordered_exprs() instead of exprs() here. - // exprs() only returns Exprs reachable from terminating outputs, which means - // dead Exprs that still reference the Val won't be found and removed. - // This causes use-after-free when copying the Fusion later. - std::vector exprs_to_remove; - for (Expr* e : unordered_exprs()) { - if (!inContainer(e)) { - continue; - } - if (std::find(e->inputs().begin(), e->inputs().end(), val) != - e->inputs().end()) { - // Avoid removing until after we've looped through exprs_ - exprs_to_remove.push_back(e); - } - } - for (auto e : exprs_to_remove) { - removeExpr(e); - } - - auto* c = ir_container(); - auto val_in_deque = std::ranges::find_if( - c->vals_up_, - [val](std::unique_ptr& val_up) { return val_up.get() == val; }); - NVF_ERROR( - val_in_deque != c->vals_up_.end(), - "Wanted to remove a value but its unique ptr is missing."); - c->per_fusion_vals_[this].erase(val); - c->vals_.erase(val); - c->vals_up_.erase(val_in_deque); - - invalidateTvsAndUses(); + std::unique_lock lock(ir_container()->mutex_); + ContainerMutator::removeVal(this, val); } void Fusion::removeStatementsCreatedAfter( int64_t num_exprs_before, int64_t num_vals_before) { - auto* c = ir_container(); - - // Remove expressions before values because we need to change Val::uses_. - while (std::ssize(c->exprsOwnedBy(this)) > num_exprs_before) { - // Pop from global deque back — statements created by this Fusion during - // the guard scope are at the tail (LIFO invariant). - Expr* e = c->exprs_up_.back().get(); - NVF_ERROR( - c->per_fusion_exprs_[this].count(e) > 0, - "removeStatementsCreatedAfter: tail expr belongs to another Fusion"); - for (Val* in : e->inputs()) { - in->removeUse(e); - } - c->per_fusion_exprs_[this].erase(e); - c->exprs_.erase(e); - c->exprs_up_.pop_back(); - } + std::unique_lock lock(ir_container()->mutex_); + ContainerMutator::removeStatementsCreatedAfter( + this, num_exprs_before, num_vals_before); +} - while (numValsExcludingShortcuts() > num_vals_before) { - Val* v = c->vals_up_.back().get(); - NVF_ERROR( - c->per_fusion_vals_[this].count(v) > 0, - "removeStatementsCreatedAfter: tail val belongs to another Fusion"); - // Null out shortcut caches if they point to vals about to be destroyed - if (v == zero_val_) { - zero_val_ = nullptr; - } else if (v == one_val_) { - one_val_ = nullptr; - } else if (v == true_val_) { - true_val_ = nullptr; - } else if (v == false_val_) { - false_val_ = nullptr; - } else if (v == magic_zero_val_) { - magic_zero_val_ = nullptr; - } - c->per_fusion_vals_[this].erase(v); - c->vals_.erase(v); - c->vals_up_.pop_back(); - } +int64_t Fusion::numValsExcludingShortcuts() const noexcept { + return ContainerMutator::numValsExcludingShortcuts(this); } void Fusion::addInput(Val* input) { @@ -981,70 +1070,13 @@ void Fusion::assumeNonNegative(Val* val) { } void Fusion::registerVal(Val* val) { - if (inContainer(val)) { - return; - } - - if (val->fusion()) { - NVF_CHECK( - val->fusion() == this, val, " was not found in the active fusion."); - } - - auto* c = ir_container(); - c->vals_up_.emplace_back(val); - c->vals_.insert(val); - c->per_fusion_vals_[this].insert(val); - val->setName(IrContainerPasskey(), getValName(val->vtype())); + std::unique_lock lock(ir_container()->mutex_); + ContainerMutator::registerVal(this, val); } void Fusion::registerExpr(Expr* expr) { - if (inContainer(expr)) { - return; - } - - if (expr->fusion()) { - NVF_CHECK( - expr->fusion() == this, expr, " was not found in the active fusion."); - } - - auto* c = ir_container(); - c->exprs_up_.emplace_back(expr); - c->exprs_.insert(expr); - c->per_fusion_exprs_[this].insert(expr); - expr->setName(IrContainerPasskey(), getExprName()); - - for (Val* input : expr->inputs()) { - assertInContainer(input, "Input to expr is invalid, "); - // Don't just add this expr as a use of the input if it's a tensor as the - // whole fusion needs to be traversed to rebuild the usage lists - if (input->isA()) { - invalidateTvsAndUses(); - } else { - input->addUse(expr); - } - } - - // Kernel and host are non-ssa. This is mainly (maybe only) because of - // initialization expressions which would overwrite tensor view definitions. - const bool is_ssa = - !this->isA() && !this->isA(); - - for (Val* output : expr->outputs()) { - assertInContainer(output, "Output to expr is invalid, "); - if (output->definition() != nullptr && is_ssa) { - removeExpr(output->definition()); - } - if (is_ssa || output->definition() == nullptr) { - output->setDefinition(expr); - if (output->isA()) { - // Updating the definition might change the path to output TVs. - // If that happens, our definition-based traversal can change and - // introduce whole new branches, so we need to recompute the uses_ - // vector after setDefinition. - invalidateTvsAndUses(); - } - } - } + std::unique_lock lock(ir_container()->mutex_); + ContainerMutator::registerExpr(this, expr); } void Fusion::resetTvUses() { diff --git a/csrc/fusion.h b/csrc/fusion.h index 34be84be28d..8044a415305 100644 --- a/csrc/fusion.h +++ b/csrc/fusion.h @@ -570,13 +570,7 @@ class NVF_API Fusion : public PolymorphicBase { //! since they're singletons that should persist across StatementGuard scopes, //! this count excludes them so the LIFO pop-back in //! removeStatementsCreatedAfter correctly skips over them. - int64_t numValsExcludingShortcuts() const noexcept { - int64_t count = std::ssize(ir_container()->valsOwnedBy(this)); - count -= (zero_val_ != nullptr) + (one_val_ != nullptr) + - (true_val_ != nullptr) + (false_val_ != nullptr) + - (magic_zero_val_ != nullptr); - return count; - } + int64_t numValsExcludingShortcuts() const noexcept; // Shortcut values (frequently used constants) Val* zeroVal(); @@ -661,6 +655,10 @@ class NVF_API Fusion : public PolymorphicBase { inline static const std::string exact_mappings_key = "exact_mappings"; std::shared_ptr ir_container_; + // PIMPL for lock-free mutation methods (defined in fusion.cpp) + struct ContainerMutator; + friend struct ContainerMutator; + Val* zero_val_ = nullptr; Val* one_val_ = nullptr; Val* true_val_ = nullptr; diff --git a/csrc/ir/container.cpp b/csrc/ir/container.cpp index 5c0b2e0f5a6..d4bdc54d41c 100644 --- a/csrc/ir/container.cpp +++ b/csrc/ir/container.cpp @@ -9,14 +9,12 @@ #include "instrumentation.h" #include "ir/base_nodes.h" -#include "ir/builder.h" -#include "ir/cloner.h" -#include "ir/internal_nodes.h" namespace nvfuser { //! Return values in insertion order const std::deque IrContainer::deterministic_vals() const noexcept { + std::shared_lock lock(mutex_); std::deque vals_deque; std::ranges::transform( vals_up_, @@ -27,6 +25,7 @@ const std::deque IrContainer::deterministic_vals() const noexcept { //! Return expression in insertion order const std::deque IrContainer::deterministic_exprs() const noexcept { + std::shared_lock lock(mutex_); std::deque exprs_deque; std::ranges::transform( exprs_up_, @@ -38,6 +37,7 @@ const std::deque IrContainer::deterministic_exprs() const noexcept { //! Return mapping from value to integer id const std::unordered_map IrContainer::deterministic_vals_map() const noexcept { + std::shared_lock lock(mutex_); std::unordered_map vals_map; int64_t count = 0; std::ranges::transform( @@ -52,6 +52,7 @@ const std::unordered_map IrContainer::deterministic_vals_map() //! Return mapping from expression to integer id const std::unordered_map IrContainer::deterministic_exprs_map() const noexcept { + std::shared_lock lock(mutex_); std::unordered_map exprs_map; int64_t count = 0; std::ranges::transform( @@ -63,57 +64,16 @@ const std::unordered_map IrContainer::deterministic_exprs_map() return exprs_map; } -void IrContainer::swap(IrContainer& a, IrContainer& b) noexcept { - FUSER_PERF_SCOPE("Fusion swap"); - - // Swap the content - std::swap(a.vals_up_, b.vals_up_); - std::swap(a.vals_, b.vals_); - - std::swap(a.exprs_up_, b.exprs_up_); - std::swap(a.exprs_, b.exprs_); - - std::swap(a.val_type_name_map_, b.val_type_name_map_); - std::swap(a.expr_name_counter_, b.expr_name_counter_); - - std::swap(a.per_fusion_vals_, b.per_fusion_vals_); - std::swap(a.per_fusion_exprs_, b.per_fusion_exprs_); -} - -IrCloner IrContainer::copy( - const IrContainer* from, - IrContainer* to, - Fusion* dest_fusion) { - to->clear(); - - IrCloner ir_cloner(dest_fusion); - - // Copy values in deterministic order - for (auto val : from->deterministic_vals()) { - if (from->vals().count(val) > 0) { - to->vals_.insert(ir_cloner.clone(val)); - } - } - - // Copy expressions in deterministic order - for (auto expr : from->deterministic_exprs()) { - if (from->unordered_exprs().count(expr) > 0) { - to->exprs_.insert(ir_cloner.clone(expr)); - } - } - - to->val_type_name_map_ = from->val_type_name_map_; - to->expr_name_counter_ = from->expr_name_counter_; - - return ir_cloner; -} - IrContainer::IrContainer() = default; IrContainer::~IrContainer() { clear(); } +// Note: clear() does not acquire mutex_. It is only called from the +// destructor and Fusion::copy(), both of which guarantee exclusive access. +// This assumption must be revisited in Phase 3 when containers may be shared +// across threads. void IrContainer::clear() noexcept { FUSER_PERF_SCOPE("IrContainer clear"); vals_.clear(); @@ -127,6 +87,11 @@ void IrContainer::clear() noexcept { } bool IrContainer::inContainer(const Statement* const_stmt) const { + std::shared_lock lock(mutex_); + return inContainerImpl(const_stmt); +} + +bool IrContainer::inContainerImpl(const Statement* const_stmt) const { // We don't use dynamic_cast here because `const_stmt` may be an invalid // pointer. Specifically a pointer to a Statement owned by another container // that has been freed. @@ -158,33 +123,67 @@ bool IrContainer::inContainer(const Statement* const_stmt) const { return true; } +void IrContainer::assertInContainerImpl( + const Statement* stmt, + const std::string& msg) const { + NVF_CHECK( + inContainerImpl(stmt), msg, " it was not found in the active container."); +} + +const std::unordered_set& IrContainer::unordered_exprs() const noexcept { + std::shared_lock lock(mutex_); + return exprs_; +} + +const std::unordered_set& IrContainer::vals() const noexcept { + std::shared_lock lock(mutex_); + return vals_; +} + +int64_t IrContainer::numExprs() const noexcept { + std::shared_lock lock(mutex_); + return std::ssize(exprs_); +} + +int64_t IrContainer::numVals() const noexcept { + std::shared_lock lock(mutex_); + return std::ssize(vals_up_); +} + void IrContainer::addFusion(Fusion* fusion) { + std::unique_lock lock(mutex_); sharing_fusions_.insert(fusion); } void IrContainer::removeFusion(Fusion* fusion) { + std::unique_lock lock(mutex_); sharing_fusions_.erase(fusion); } void IrContainer::transferFusion(Fusion* from, Fusion* to) { + std::unique_lock lock(mutex_); sharing_fusions_.erase(from); sharing_fusions_.insert(to); } size_t IrContainer::sharingCount() const { + std::shared_lock lock(mutex_); return sharing_fusions_.size(); } bool IrContainer::hasMultipleFusions() const { + std::shared_lock lock(mutex_); return sharing_fusions_.size() > 1; } const std::unordered_set& IrContainer::sharingFusions() const { + std::shared_lock lock(mutex_); return sharing_fusions_; } const std::unordered_set& IrContainer::valsOwnedBy( const Fusion* fusion) const { + std::shared_lock lock(mutex_); static const std::unordered_set empty; auto it = per_fusion_vals_.find(fusion); return it != per_fusion_vals_.end() ? it->second : empty; @@ -192,6 +191,7 @@ const std::unordered_set& IrContainer::valsOwnedBy( const std::unordered_set& IrContainer::exprsOwnedBy( const Fusion* fusion) const { + std::shared_lock lock(mutex_); static const std::unordered_set empty; auto it = per_fusion_exprs_.find(fusion); return it != per_fusion_exprs_.end() ? it->second : empty; @@ -200,6 +200,7 @@ const std::unordered_set& IrContainer::exprsOwnedBy( void IrContainer::transferStatementOwnership( const Fusion* from, const Fusion* to) { + std::unique_lock lock(mutex_); auto vals_it = per_fusion_vals_.find(from); if (vals_it != per_fusion_vals_.end()) { auto& to_vals = per_fusion_vals_[to]; @@ -216,6 +217,7 @@ void IrContainer::transferStatementOwnership( } void IrContainer::removeStatementsOwnedBy(const Fusion* fusion) { + std::unique_lock lock(mutex_); auto vals_it = per_fusion_vals_.find(fusion); if (vals_it != per_fusion_vals_.end()) { const auto& owned = vals_it->second; @@ -245,6 +247,7 @@ void IrContainer::removeStatementsOwnedBy(const Fusion* fusion) { std::deque IrContainer::deterministicValsOwnedBy( const Fusion* fusion) const noexcept { + std::shared_lock lock(mutex_); std::deque result; auto it = per_fusion_vals_.find(fusion); if (it == per_fusion_vals_.end()) { @@ -261,6 +264,7 @@ std::deque IrContainer::deterministicValsOwnedBy( std::deque IrContainer::deterministicExprsOwnedBy( const Fusion* fusion) const noexcept { + std::shared_lock lock(mutex_); std::deque result; auto it = per_fusion_exprs_.find(fusion); if (it == per_fusion_exprs_.end()) { @@ -277,6 +281,7 @@ std::deque IrContainer::deterministicExprsOwnedBy( std::unordered_map IrContainer::deterministicValsMapOwnedBy( const Fusion* fusion) const noexcept { + std::shared_lock lock(mutex_); std::unordered_map result; auto it = per_fusion_vals_.find(fusion); if (it == per_fusion_vals_.end()) { @@ -294,6 +299,7 @@ std::unordered_map IrContainer::deterministicValsMapOwnedBy( std::unordered_map IrContainer::deterministicExprsMapOwnedBy( const Fusion* fusion) const noexcept { + std::shared_lock lock(mutex_); std::unordered_map result; auto it = per_fusion_exprs_.find(fusion); if (it == per_fusion_exprs_.end()) { diff --git a/csrc/ir/container.h b/csrc/ir/container.h index ed0b4504840..a9555ae3305 100644 --- a/csrc/ir/container.h +++ b/csrc/ir/container.h @@ -8,6 +8,7 @@ #pragma once #include +#include #include #include @@ -68,34 +69,21 @@ class IrContainer { //! Return the set of Exprs registered with this fusion. Warning: This will //! return exprs outside inputs/outputs, so can be unsafe for use with //! segmented fusions. - const std::unordered_set& unordered_exprs() const noexcept { - return exprs_; - } + const std::unordered_set& unordered_exprs() const noexcept; //! Return the set of Vals registered with this fusion - const std::unordered_set& vals() const noexcept { - return vals_; - } + const std::unordered_set& vals() const noexcept; - int64_t numExprs() const noexcept { - return std::ssize(exprs_); - } + int64_t numExprs() const noexcept; - int64_t numVals() const noexcept { - return std::ssize(vals_up_); - } + int64_t numVals() const noexcept; protected: - static IrCloner copy( - const IrContainer* from, - IrContainer* to, - Fusion* dest_fusion); - - static void swap(IrContainer& a, IrContainer& b) noexcept; - - // Let Fusion access IrContainer::clear() + // Let Fusion access IrContainer internals (mutex_, fields, Impl helpers) friend class Fusion; + mutable std::shared_mutex mutex_; + StmtNameType getValName(ValType vtype) { if (val_type_name_map_.find(vtype) == val_type_name_map_.end()) { val_type_name_map_[vtype] = 0; @@ -153,6 +141,11 @@ class IrContainer { const Fusion* fusion) const noexcept; private: + // Lock-free implementations for use by Fusion (which holds mutex_ directly) + bool inContainerImpl(const Statement* stmt) const; + void assertInContainerImpl(const Statement* stmt, const std::string& msg) + const; + std::unordered_set sharing_fusions_; std::unordered_map> per_fusion_vals_; std::unordered_map> diff --git a/csrc/runtime/fusion_kernel_runtime.cpp b/csrc/runtime/fusion_kernel_runtime.cpp index a76bb19d563..070fc1b27eb 100644 --- a/csrc/runtime/fusion_kernel_runtime.cpp +++ b/csrc/runtime/fusion_kernel_runtime.cpp @@ -26,9 +26,6 @@ namespace nvfuser { -// TODO: Remove when std::shared_mutex is added to IrContainer. -constexpr bool kPhase2DisableParallelCompile = true; - namespace { // Replace CUDA tensor with Meta tensor because storing tensors can cause // out-of-memory issues. Other arguments are returned as-is. @@ -439,8 +436,7 @@ void FusionKernelRuntime::compileFusionParallel(KernelArgumentHolder args) { try { for (const auto& [group_to_run, group_runtime_inputs] : zip(runtime_workspace_.group_run_order, all_runtime_inputs)) { - if (num_groups == 1 || kPhase2DisableParallelCompile || - isOptionDisabled(DisableOption::ParallelCompile)) { + if (num_groups == 1 || isOptionDisabled(DisableOption::ParallelCompile)) { compileKernel(group_runtime_inputs, group_to_run); } else { // launch compileKernel thread here @@ -474,8 +470,7 @@ void FusionKernelRuntime::compileFusionParallel(KernelArgumentHolder args) { throw; } - if (num_groups != 1 && !kPhase2DisableParallelCompile && - !isOptionDisabled(DisableOption::ParallelCompile)) { + if (num_groups != 1 && !isOptionDisabled(DisableOption::ParallelCompile)) { // Wait until all segments finish compiling getThreadPool()->waitWorkComplete(); NVF_ERROR(