SQLAlchemy with multiple binds - dynamically choose the bind to query


I have 4 different databases, one for each of my customers (medical clinics), and all of them have exactly the same structure.

In my application I have models such as Patient, Doctor, Appointment, etc.


class Patient(db.Model):
    __tablename__ = "patients"

    id = Column(Integer, primary_key=True)
    first_name = Column(String, index=True)
    last_name = Column(String, index=True)
    date_of_birth = Column(Date, index=True)


app = Flask(__name__)
app.config['SQLALCHEMY_DATABASE_URI'] = 'mysql://user:pass@localhost/main'
app.config['SQLALCHEMY_BINDS'] = {
    # One entry per clinic bind key (clinic1 .. clinic4)
}


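  1. I want db.create_all() to create the patients table in all 4 databases (clinic1->clinic4).
  2. I want to be able to choose a specific bind dynamically (at runtime), so that any query such as Patient.query.filter().count() runs against the chosen bind's database.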


with DbContext(bind='clinic1'):
    patients_count = Patient.query.filter().count()

# outside of the `with` context we are back to the default bind


patients_count = Patient.query.filter().count()





Observation: db.create_all() calls self.get_tables_for_bind().

Solution: override SQLAlchemy get_tables_for_bind() to support '__all__'.

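from flask_sqlalchemy import SQLAlchemy  # Flask-SQLAlchemy 2.x API

class MySQLAlchemy(SQLAlchemy):

    def get_tables_for_bind(self, bind=None):
        result = []
        for table in self.Model.metadata.tables.values():
            # Original check: if table.info.get('bind_key') == bind:
            # '__all__' additionally matches every named bind (but not the default bind, None).
            if table.info.get('bind_key') == bind or (bind is not None and table.info.get('bind_key') == '__all__'):
                result.append(table)
        return result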


# db = SQLAlchemy(app)  # Replace this
db = MySQLAlchemy(app)  # with this
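
To see which tables the '__all__' rule selects for each bind, here is a minimal stand-alone sketch of the same condition. It does not use Flask-SQLAlchemy: the tables are hypothetical stand-ins built with SimpleNamespace, and the users and audit tables are assumptions made up for the illustration.

```python
from types import SimpleNamespace


def tables_for_bind(tables, bind=None):
    """Select tables for `bind`; '__all__' matches every named bind, not the default (None)."""
    result = []
    for table in tables:
        key = table.info.get('bind_key')
        if key == bind or (bind is not None and key == '__all__'):
            result.append(table)
    return result


# Hypothetical stand-ins for Table objects: only .name and .info are modeled.
patients = SimpleNamespace(name='patients', info={'bind_key': '__all__'})
users = SimpleNamespace(name='users', info={})                      # default bind
audit = SimpleNamespace(name='audit', info={'bind_key': 'clinic1'})
tables = [patients, users, audit]

default_names = [t.name for t in tables_for_bind(tables)]
clinic1_names = [t.name for t in tables_for_bind(tables, 'clinic1')]
clinic2_names = [t.name for t in tables_for_bind(tables, 'clinic2')]

print(default_names)  # ['users']
print(clinic1_names)  # ['patients', 'audit']
print(clinic2_names)  # ['patients']
```

Under this rule a table marked '__all__' is created in every named bind but not in the default database.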



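Observation: SignallingSession get_bind() is responsible for determining the bind.

Solution:

  1. Override SignallingSession get_bind() to get the bind key from some context.
  2. Override SQLAlchemy create_session() to use our custom session class.
  3. Support a context for choosing a specific bind on db, for easy access.
  4. Force a context to be specified for tables with '__all__' as the bind key, by overriding SQLAlchemy get_binds() to restore the default engine.
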
from contextlib import contextmanager

from flask_sqlalchemy import SQLAlchemy, SignallingSession  # Flask-SQLAlchemy 2.x API
from sqlalchemy import orm


class MySignallingSession(SignallingSession):
    def __init__(self, db, *args, **kwargs):
        super().__init__(db, *args, **kwargs)
        self.db = db

    def get_bind(self, mapper=None, clause=None):
        if mapper is not None:
            info = getattr(mapper.persist_selectable, 'info', {})
            if info.get('bind_key') == '__all__':
                # Temporarily swap in the context's bind key, then restore the '__all__' marker.
                info['bind_key'] = self.db.context_bind_key
                try:
                    return super().get_bind(mapper=mapper, clause=clause)
                finally:
                    info['bind_key'] = '__all__'
        return super().get_bind(mapper=mapper, clause=clause)


class MySQLAlchemy(SQLAlchemy):
    context_bind_key = None

    @contextmanager
    def context(self, bind=None):
        # Remember the previous key so that nested contexts restore it on exit.
        _context_bind_key = self.context_bind_key
        try:
            self.context_bind_key = bind
            yield
        finally:
            self.context_bind_key = _context_bind_key

    def create_session(self, options):
        return orm.sessionmaker(class_=MySignallingSession, db=self, **options)

    def get_binds(self, app=None):
        binds = super().get_binds(app=app)
        # Restore default engine for table.info.get('bind_key') == '__all__'
        app = self.get_app(app)
        engine = self.get_engine(app, None)
        tables = self.get_tables_for_bind('__all__')
        binds.update(dict((table, engine) for table in tables))
        return binds

    def get_tables_for_bind(self, bind=None):
        result = []
        for table in self.Model.metadata.tables.values():
            if table.info.get('bind_key') == bind or (bind is not None and table.info.get('bind_key') == '__all__'):
                result.append(table)
        return result


class Patient(db.Model):
    __tablename__ = "patients"
    __bind_key__ = "__all__"  # Add this


with db.context(bind='clinic1'):
    db.session.add(Patient())  # Assumed step: a new patient, implied by the counts below
    db.session.flush()         # Flush in 'clinic1'
    with db.context(bind='clinic2'):
        patients_count = Patient.query.filter().count()
        print(patients_count)  # 0 in 'clinic2'
    patients_count = Patient.query.filter().count()
    print(patients_count)      # 1 in 'clinic1'
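
The nested usage above relies on the context saving the previous bind key on entry and restoring it on exit, even when the body raises. A minimal stand-alone sketch of that save-and-restore pattern follows, using only the standard library. BindContext is a hypothetical stand-in for the db object, not part of Flask-SQLAlchemy.

```python
from contextlib import contextmanager


class BindContext:
    """Hypothetical stand-in for `db`: tracks the selected bind key (None = default bind)."""
    context_bind_key = None

    @contextmanager
    def context(self, bind=None):
        previous = self.context_bind_key
        try:
            self.context_bind_key = bind
            yield
        finally:
            # Runs on normal exit and on exceptions, so the outer key always comes back.
            self.context_bind_key = previous


db_ctx = BindContext()
seen = []
with db_ctx.context(bind='clinic1'):
    seen.append(db_ctx.context_bind_key)
    with db_ctx.context(bind='clinic2'):
        seen.append(db_ctx.context_bind_key)
    seen.append(db_ctx.context_bind_key)
seen.append(db_ctx.context_bind_key)

print(seen)  # ['clinic1', 'clinic2', 'clinic1', None]
```

Note that the key lives on a single shared object here, so this sketch says nothing about safety across threads or concurrent requests.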




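Limitations when referencing a table in the default bind from another bind:

  • MySQL:
    • The binds must be in the same MySQL instance. Otherwise, the referencing column has to be a plain column (no foreign key).
    • The foreign object in the default bind must already be committed.
      Otherwise, inserting an object that references it fails with a lock error.
  • SQLite: foreign keys across databases are not enforced.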


# app.config['SQLALCHEMY_DATABASE_URI'] = 'mysql://user:pass@localhost/main'

class PatientType(db.Model):
    __tablename__ = "patient_types"
    __table_args__ = {"schema": "main"}  # Add this, based on database name

    id = Column(Integer, primary_key=True)
    # ...

class Patient(db.Model):
    __tablename__ = "patients"
    __bind_key__ = "__all__"

    id = Column(Integer, primary_key=True)
    # ...
    # patient_type_id = Column(Integer, ForeignKey("patient_types.id"))     # Replace this
    patient_type_id = Column(Integer, ForeignKey("main.patient_types.id"))  # with this
    patient_type = relationship("PatientType")


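patient_type = PatientType.query.first()
if not patient_type:
    patient_type = PatientType()
    db.session.add(patient_type)  # Assumed step: implied by the commit below
    db.session.commit()        # Commit to reference from other binds

with db.context(bind='clinic1'):
    db.session.add(Patient(patient_type=patient_type))  # Assumed step: implied by the counts below
    db.session.flush()         # Flush in 'clinic1'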
    with db.context(bind='clinic2'):
        patients_count = Patient.query.filter().count()
        print(patients_count)  # 0 in 'clinic2'
    patients_count = Patient.query.filter().count()
    print(patients_count)      # 1 in 'clinic1'
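
For the plain-column case mentioned in the limitations, here is a small stand-alone sketch using the standard-library sqlite3 module rather than Flask-SQLAlchemy. It attaches a second in-memory database named clinic1 and reads across both with schema-qualified names. The name column and the sample rows are assumptions made up for the illustration.

```python
import sqlite3

# The first in-memory database is the schema "main"; attach a second one as "clinic1".
conn = sqlite3.connect(":memory:")
conn.execute("ATTACH DATABASE ':memory:' AS clinic1")

conn.execute("CREATE TABLE main.patient_types (id INTEGER PRIMARY KEY, name TEXT)")
# patient_type_id is a plain column: no foreign key constraint is declared.
conn.execute("CREATE TABLE clinic1.patients (id INTEGER PRIMARY KEY, patient_type_id INTEGER)")

conn.execute("INSERT INTO main.patient_types (id, name) VALUES (1, 'adult')")
conn.execute("INSERT INTO clinic1.patients (id, patient_type_id) VALUES (1, 1)")
# Nothing checks the reference, so a dangling id is accepted as well.
conn.execute("INSERT INTO clinic1.patients (id, patient_type_id) VALUES (2, 99)")

joined = conn.execute(
    "SELECT p.id, t.name FROM clinic1.patients AS p "
    "LEFT JOIN main.patient_types AS t ON t.id = p.patient_type_id "
    "ORDER BY p.id"
).fetchall()

print(joined)  # [(1, 'adult'), (2, None)]
```

The second row shows the trade-off of a plain column: the cross-database lookup works, but integrity of the reference is left to the application.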

