From 10e51b365544e7e81fd0199dccf67d5557f2966a Mon Sep 17 00:00:00 2001 From: saulteafarmer Date: Wed, 14 May 2025 15:12:23 +0000 Subject: [PATCH] Update discord_lnbits_bot.py --- discord_lnbits_bot.py | 198 ++++++++++++++++++++---------------------- 1 file changed, 94 insertions(+), 104 deletions(-) diff --git a/discord_lnbits_bot.py b/discord_lnbits_bot.py index 3f24554..c6c9d48 100644 --- a/discord_lnbits_bot.py +++ b/discord_lnbits_bot.py @@ -7,7 +7,7 @@ import requests import logging from datetime import datetime from dateutil.relativedelta import relativedelta -from dotenv import load_dotenv +import qrcode import discord from discord import File, Embed @@ -18,6 +18,7 @@ import websockets from sqlalchemy import ( create_engine, Column, String, Integer, DateTime, Text, Enum, ForeignKey ) +from sqlalchemy.dialects.postgresql import JSONB from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.orm import sessionmaker, relationship @@ -30,18 +31,16 @@ logging.basicConfig( ) logger = logging.getLogger("discord_lnbits_bot") -# ─── Load minimal .env & DB setup ───────────────────────────────────────────── -load_dotenv() # for DATABASE_URL only +# ─── DB Setup & Config Seeding ───────────────────────────────────────────────── DATABASE_URL = os.getenv("DATABASE_URL") if not DATABASE_URL: - logger.critical("Environment variable DATABASE_URL is missing.") + logger.critical("Missing DATABASE_URL in .env") exit(1) engine = create_engine(DATABASE_URL, future=True) SessionLocal = sessionmaker(bind=engine, expire_on_commit=False) Base = declarative_base() -# ─── ORM Models ─────────────────────────────────────────────────────────────── class Config(Base): __tablename__ = "config" key = Column(String, primary_key=True) @@ -56,7 +55,10 @@ class Plan(Base): channel_id = Column(String(64), nullable=False) price_sats = Column(Integer, nullable=False) invoice_message = Column(Text, nullable=False) - expiry_type = Column(Enum("fixed_date","rolling_days","calendar_month", name="expiry_type"), nullable=False) + expiry_type = Column( + Enum("fixed_date", "rolling_days", "calendar_month", name="expiry_type"), + nullable=False + ) duration_days = Column(Integer, default=30) subs = relationship("Subscription", back_populates="plan") @@ -67,25 +69,36 @@ class Subscription(Base): user_id = Column(String(64), nullable=False) invoice_hash = Column(String(256), unique=True, nullable=False) invoice_request = Column(Text, nullable=False) - status = Column(Enum("pending","active","expired","cancelled", name="sub_status"), nullable=False, default="pending") + status = Column( + Enum("pending","active","expired","cancelled", name="sub_status"), + nullable=False, default="pending" + ) created_at = Column(DateTime, default=datetime.utcnow) paid_at = Column(DateTime) expires_at = Column(DateTime) role_assigned_at = Column(DateTime) role_removed_at = Column(DateTime) - raw_invoice_json = Column(Text) - raw_payment_json = Column(Text) + raw_invoice_json = Column(JSONB) + raw_payment_json = Column(JSONB) plan = relationship("Plan", back_populates="subs") +# create tables & seed Config keys Base.metadata.create_all(engine) logger.info("✅ Database schema ready.") +session = SessionLocal() +for key in ("discord_token", "guild_id", "lnbits_url", "lnbits_api_key"): + if not session.get(Config, key): + session.add(Config(key=key, value="")) +session.commit() +session.close() + # ─── Config helper ──────────────────────────────────────────────────────────── def get_cfg(key: str) -> str: - session = SessionLocal() - cfg = session.get(Config, key) - session.close() - return cfg.value if cfg else "" + db = SessionLocal() + row = db.get(Config, key) + db.close() + return row.value if row else "" # Read core settings from DB DISCORD_TOKEN = get_cfg("discord_token") @@ -93,15 +106,14 @@ GUILD_ID = int(get_cfg("guild_id") or 0) LNBITS_URL = get_cfg("lnbits_url") LNBITS_API_KEY = get_cfg("lnbits_api_key") -# Validate -for var_name, val in [ +for name, val in ( ("discord_token", DISCORD_TOKEN), ("guild_id", GUILD_ID), ("lnbits_url", LNBITS_URL), ("lnbits_api_key", LNBITS_API_KEY), -]: +): if not val: - logger.critical(f"Config key '{var_name}' is missing or empty in DB.") + logger.critical(f"Config '{name}' is empty in DB. Fill it via the UI first.") exit(1) # Build WebSocket URL @@ -119,17 +131,16 @@ intents = discord.Intents.default() intents.members = True bot = commands.Bot(command_prefix="!", intents=intents) -# ─── Utility: expiry computation ─────────────────────────────────────────────── +# ─── Expiry computation ──────────────────────────────────────────────────────── def compute_expiry(paid_at: datetime, plan: Plan) -> datetime: if plan.expiry_type == "calendar_month": first = paid_at.replace(day=1, hour=0, minute=0, second=0, microsecond=0) return first + relativedelta(months=1) - elif plan.expiry_type == "rolling_days": + if plan.expiry_type == "rolling_days": return paid_at + relativedelta(days=plan.duration_days) - else: # fixed_date - return paid_at # assume UI sets exact expires_at elsewhere + return paid_at # fixed_date handled via UI -# ─── Handle a plan purchase ─────────────────────────────────────────────────── +# ─── Handle plan purchase ───────────────────────────────────────────────────── async def handle_plan_purchase(interaction: discord.Interaction, plan: Plan): invoice_ch = bot.get_channel(int(plan.channel_id)) if not invoice_ch: @@ -149,47 +160,40 @@ async def handle_plan_purchase(interaction: discord.Interaction, plan: Plan): loop = asyncio.get_running_loop() def create_invoice(): - resp = requests.post( + r = requests.post( f"{LNBITS_URL}/api/v1/payments", - json=invoice_data, - headers=headers, - timeout=10 + json=invoice_data, headers=headers, timeout=10 ) - resp.raise_for_status() - return resp.json() + r.raise_for_status() + return r.json() try: inv = await loop.run_in_executor(None, create_invoice) except Exception: logger.exception("Invoice creation failed") return await interaction.response.send_message( - "❌ Failed to generate invoice. Try again later.", - ephemeral=True + "❌ Failed to generate invoice.", ephemeral=True ) - pay_req = inv["bolt11"] - pay_hash = inv["payment_hash"] + pay_req = inv["bolt11"] + pay_hash = inv["payment_hash"] - # Store pending subscription db = SessionLocal() sub = Subscription( plan_id = plan.id, user_id = str(interaction.user.id), invoice_hash = pay_hash, invoice_request = pay_req, - raw_invoice_json= json.dumps(inv) + raw_invoice_json= inv ) - db.add(sub) - db.commit() - db.close() - logger.info(f"Stored pending subscription {pay_hash} for user {interaction.user.id}") + db.add(sub); db.commit(); db.close() + logger.info(f"Stored pending {pay_hash} for user {interaction.user.id}") - # Generate QR code def make_qr_buf(): - img = requests.compat.BytesIO() - qrcode.make(pay_req).save(img, format="PNG") - img.seek(0) - return img + buf = io.BytesIO() + qrcode.make(pay_req).save(buf, format="PNG") + buf.seek(0) + return buf qr_buf = await loop.run_in_executor(None, make_qr_buf) qr_file = File(qr_buf, filename="invoice.png") @@ -208,7 +212,7 @@ async def handle_plan_purchase(interaction: discord.Interaction, plan: Plan): f"✅ Invoice posted in {invoice_ch.mention}", ephemeral=True ) -# ─── WebSocket listener for payments ────────────────────────────────────────── +# ─── WS listener ─────────────────────────────────────────────────────────────── async def lnbits_ws_listener(): await bot.wait_until_ready() logger.info(f"Listening to LNbits WS at {LNBITS_WS_URL}") @@ -216,57 +220,50 @@ async def lnbits_ws_listener(): try: async with websockets.connect(LNBITS_WS_URL) as ws: async for msg in ws: - try: - data = json.loads(msg) - pay = data.get("payment", {}) - hsh = pay.get("checking_id") or pay.get("payment_hash") - if not (hsh and pay.get("status") == "success"): - continue + data = json.loads(msg) + pay = data.get("payment", {}) + hsh = pay.get("checking_id") or pay.get("payment_hash") + if not (hsh and pay.get("status")=="success"): + continue - # Activate subscription - db = SessionLocal() - sub = db.query(Subscription)\ - .filter_by(invoice_hash=hsh, status="pending")\ - .first() - if not sub: - db.close() - continue + db = SessionLocal() + sub = db.query(Subscription)\ + .filter_by(invoice_hash=hsh, status="pending")\ + .first() + if not sub: + db.close(); continue - paid_at = datetime.utcnow() - expires_at = compute_expiry(paid_at, sub.plan) + paid_at = datetime.utcnow() + expires_at = compute_expiry(paid_at, sub.plan) - sub.status = "active" - sub.paid_at = paid_at - sub.expires_at = expires_at - sub.raw_payment_json = json.dumps(pay) - db.commit() - db.close() + sub.status = "active" + sub.paid_at = paid_at + sub.expires_at = expires_at + sub.raw_payment_json = pay + db.commit(); db.close() - # Assign role & announce - guild = bot.get_guild(GUILD_ID) - member = guild.get_member(int(sub.user_id)) or await guild.fetch_member(int(sub.user_id)) - role = guild.get_role(int(sub.plan.role_id)) - chan = bot.get_channel(int(sub.plan.channel_id)) + guild = bot.get_guild(GUILD_ID) + member = guild.get_member(int(sub.user_id)) or await guild.fetch_member(int(sub.user_id)) + role = guild.get_role(int(sub.plan.role_id)) + chan = bot.get_channel(int(sub.plan.channel_id)) - await member.add_roles(role, reason="Subscription paid") - await chan.send( - f"🎉 {member.mention} paid **{sub.plan.price_sats} sats**" - f" for **{sub.plan.name}**! Expires {expires_at:%Y-%m-%d}." - ) - except Exception: - logger.exception("Error processing WS message") + await member.add_roles(role, reason="Subscription paid") + await chan.send( + f"🎉 {member.mention} paid **{sub.plan.price_sats} sats** " + f"for **{sub.plan.name}**, expires {expires_at:%Y-%m-%d}." + ) except Exception: - logger.exception("WS connection error, retry in 10s") + logger.exception("WS error; reconnecting in 10s") await asyncio.sleep(10) -# ─── Cleanup expired subscriptions ───────────────────────────────────────────── +# ─── Cleanup expired ─────────────────────────────────────────────────────────── async def cleanup_expired(): now = datetime.utcnow() db = SessionLocal() - expired = db.query(Subscription)\ - .filter(Subscription.status=="active", Subscription.expires_at <= now)\ - .all() - for sub in expired: + rows = db.query(Subscription)\ + .filter(Subscription.status=="active", Subscription.expires_at <= now)\ + .all() + for sub in rows: guild = bot.get_guild(GUILD_ID) member = guild.get_member(int(sub.user_id)) role = guild.get_role(int(sub.plan.role_id)) @@ -276,42 +273,35 @@ async def cleanup_expired(): sub.status = "expired" sub.role_removed_at = now db.commit() - logger.info(f"Removed {role.name} from {member.display_name}") + logger.info(f"Removed expired {role.name} from {member.display_name}") except Exception: - logger.exception("Failed to remove expired role") + logger.exception("Error removing expired role") db.close() -# ─── on_ready: dynamic commands + start jobs ────────────────────────────────── +# ─── on_ready & scheduler ───────────────────────────────────────────────────── @bot.event async def on_ready(): - logger.info(f"Logged in as {bot.user} (ID: {bot.user.id})") + logger.info(f"Bot logged in as {bot.user} (ID {bot.user.id})") - # Register commands for each plan + # dynamic command registration db = SessionLocal() plans = db.query(Plan).all() for plan in plans: - @bot.tree.command(name=plan.command_name, description=f"Buy {plan.name}") + @bot.tree.command(name=plan.command_name, + description=f"Buy {plan.name} for {plan.price_sats} sats") async def _cmd(interaction: discord.Interaction, plan=plan): await handle_plan_purchase(interaction, plan) db.close() await bot.tree.sync() - # Start WS listener + # start WS listener bot.loop.create_task(lnbits_ws_listener()) - # Schedule cleanup daily at midnight UTC - scheduler = AsyncIOScheduler(timezone="UTC") - scheduler.add_job(cleanup_expired, "cron", hour=0, minute=0) - scheduler.start() + # schedule midnight‐UTC expiry check + sched = AsyncIOScheduler(timezone="UTC") + sched.add_job(cleanup_expired, "cron", hour=0, minute=0) + sched.start() -# ─── Entry point ─────────────────────────────────────────────────────────────── +# ─── Run ─────────────────────────────────────────────────────────────────────── if __name__ == "__main__": - # Seed Config table with required keys if missing - session = SessionLocal() - for key in ("discord_token","guild_id","lnbits_url","lnbits_api_key"): - if not session.get(Config, key): - session.add(Config(key=key, value="")) - session.commit() - session.close() - bot.run(DISCORD_TOKEN)