feat: multi-company switcher backend (helper, context processor, switch endpoint, session init)
- Add helpers/company_context.py with get_active_company_id() fallback logic - Add inject_company_context() context processor to app.py (user_companies, active_company, has_multiple_companies) - Add /api/switch-company/<id> POST endpoint in public blueprint - Set session['active_company_id'] on login (both standard and 2FA paths) Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
parent
2bf5c780e2
commit
1598f93c58
45
app.py
45
app.py
@ -354,6 +354,51 @@ def inject_audit_access():
|
||||
return dict(is_audit_owner=is_audit_owner())
|
||||
|
||||
|
||||
@app.context_processor
|
||||
def inject_company_context():
|
||||
"""Inject multi-company context into all templates."""
|
||||
if not current_user.is_authenticated or not current_user.company_id:
|
||||
return {}
|
||||
|
||||
from database import UserCompany
|
||||
from helpers.company_context import get_active_company_id
|
||||
|
||||
db = SessionLocal()
|
||||
try:
|
||||
user_companies = db.query(UserCompany).filter_by(
|
||||
user_id=current_user.id
|
||||
).order_by(UserCompany.is_primary.desc(), UserCompany.created_at.asc()).all()
|
||||
|
||||
# Eager-load company objects while session is open
|
||||
for uc in user_companies:
|
||||
_ = uc.company.name if uc.company else None
|
||||
|
||||
active_cid = get_active_company_id()
|
||||
|
||||
# Validate active_company_id is still valid for this user
|
||||
valid_ids = {uc.company_id for uc in user_companies}
|
||||
if active_cid not in valid_ids:
|
||||
active_cid = current_user.company_id
|
||||
session.pop('active_company_id', None)
|
||||
|
||||
active_company = None
|
||||
for uc in user_companies:
|
||||
if uc.company_id == active_cid:
|
||||
active_company = uc.company
|
||||
break
|
||||
|
||||
return {
|
||||
'user_companies': user_companies,
|
||||
'active_company_id': active_cid,
|
||||
'active_company': active_company,
|
||||
'has_multiple_companies': len(user_companies) > 1,
|
||||
}
|
||||
except Exception:
|
||||
return {}
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
@app.context_processor
|
||||
def inject_notifications():
|
||||
"""Inject unread notifications count into all templates"""
|
||||
|
||||
@ -379,6 +379,7 @@ def login():
|
||||
|
||||
# No 2FA - login directly
|
||||
login_user(user, remember=remember)
|
||||
session['active_company_id'] = user.company_id
|
||||
user.last_login = datetime.now()
|
||||
user.login_count = (user.login_count or 0) + 1
|
||||
_auto_link_person(db, user)
|
||||
@ -477,6 +478,7 @@ def verify_2fa():
|
||||
next_page = session.pop('2fa_next', None)
|
||||
|
||||
login_user(user, remember=remember)
|
||||
session['active_company_id'] = user.company_id
|
||||
session['2fa_verified'] = True
|
||||
user.last_login = datetime.now()
|
||||
user.login_count = (user.login_count or 0) + 1
|
||||
|
||||
@ -2701,3 +2701,25 @@ def sitemap_xml():
|
||||
xml_parts.append('</urlset>')
|
||||
|
||||
return Response('\n'.join(xml_parts), mimetype='application/xml')
|
||||
|
||||
|
||||
@bp.route('/api/switch-company/<int:company_id>', methods=['POST'])
|
||||
@login_required
|
||||
def switch_company(company_id):
|
||||
"""Switch the active company context for multi-company users."""
|
||||
db = SessionLocal()
|
||||
try:
|
||||
uc = db.query(UserCompany).filter_by(
|
||||
user_id=current_user.id,
|
||||
company_id=company_id
|
||||
).first()
|
||||
|
||||
if not uc:
|
||||
flash('Nie masz uprawnień do tej firmy.', 'error')
|
||||
else:
|
||||
session['active_company_id'] = company_id
|
||||
flash(f'Przełączono na firmę: {uc.company.name}', 'info')
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
return redirect(request.referrer or url_for('dashboard'))
|
||||
|
||||
1
helpers/__init__.py
Normal file
1
helpers/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
# helpers package
|
||||
14
helpers/company_context.py
Normal file
14
helpers/company_context.py
Normal file
@ -0,0 +1,14 @@
|
||||
"""Company context helpers for multi-company users."""
|
||||
|
||||
from flask import session
|
||||
from flask_login import current_user
|
||||
|
||||
|
||||
def get_active_company_id():
|
||||
"""Return the active company ID from session, falling back to users.company_id."""
|
||||
if not current_user.is_authenticated:
|
||||
return None
|
||||
active_id = session.get('active_company_id')
|
||||
if active_id:
|
||||
return active_id
|
||||
return current_user.company_id
|
||||
Loading…
Reference in New Issue
Block a user