From a07cdecc32dfe3a77156a0a9b123939e8b116947 Mon Sep 17 00:00:00 2001 From: saulteafarmer Date: Wed, 14 May 2025 16:15:24 +0000 Subject: [PATCH] Update discord_lnbits_bot.py --- discord_lnbits_bot.py | 198 +++++++++++++----------------------------- 1 file changed, 58 insertions(+), 140 deletions(-) diff --git a/discord_lnbits_bot.py b/discord_lnbits_bot.py index c6c9d48..09254a4 100644 --- a/discord_lnbits_bot.py +++ b/discord_lnbits_bot.py @@ -24,19 +24,12 @@ from sqlalchemy.orm import sessionmaker, relationship from apscheduler.schedulers.asyncio import AsyncIOScheduler -# ─── Logging ────────────────────────────────────────────────────────────────── -logging.basicConfig( - level=logging.INFO, - format="%(asctime)s %(levelname)s %(name)s | %(message)s" -) +# ─── Logging ────────────────────────────────────────────────────────────── +logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(name)s | %(message)s") logger = logging.getLogger("discord_lnbits_bot") -# ─── DB Setup & Config Seeding ───────────────────────────────────────────────── -DATABASE_URL = os.getenv("DATABASE_URL") -if not DATABASE_URL: - logger.critical("Missing DATABASE_URL in .env") - exit(1) - +# ─── Database & Config Seeding ──────────────────────────────────────────── +DATABASE_URL = os.environ["DATABASE_URL"] engine = create_engine(DATABASE_URL, future=True) SessionLocal = sessionmaker(bind=engine, expire_on_commit=False) Base = declarative_base() @@ -55,10 +48,7 @@ class Plan(Base): channel_id = Column(String(64), nullable=False) price_sats = Column(Integer, nullable=False) invoice_message = Column(Text, nullable=False) - expiry_type = Column( - Enum("fixed_date", "rolling_days", "calendar_month", name="expiry_type"), - nullable=False - ) + expiry_type = Column(Enum("fixed_date","rolling_days","calendar_month", name="expiry_type"), nullable=False) duration_days = Column(Integer, default=30) subs = relationship("Subscription", back_populates="plan") @@ -69,10 +59,7 @@ class Subscription(Base): user_id = Column(String(64), nullable=False) invoice_hash = Column(String(256), unique=True, nullable=False) invoice_request = Column(Text, nullable=False) - status = Column( - Enum("pending","active","expired","cancelled", name="sub_status"), - nullable=False, default="pending" - ) + status = Column(Enum("pending","active","expired","cancelled", name="sub_status"), nullable=False, default="pending") created_at = Column(DateTime, default=datetime.utcnow) paid_at = Column(DateTime) expires_at = Column(DateTime) @@ -82,10 +69,10 @@ class Subscription(Base): raw_payment_json = Column(JSONB) plan = relationship("Plan", back_populates="subs") -# create tables & seed Config keys Base.metadata.create_all(engine) logger.info("✅ Database schema ready.") +# Seed Config keys (if missing) session = SessionLocal() for key in ("discord_token", "guild_id", "lnbits_url", "lnbits_api_key"): if not session.get(Config, key): @@ -93,30 +80,25 @@ for key in ("discord_token", "guild_id", "lnbits_url", "lnbits_api_key"): session.commit() session.close() -# ─── Config helper ──────────────────────────────────────────────────────────── -def get_cfg(key: str) -> str: +# ─── Config helper ───────────────────────────────────────────────────────── +def get_cfg(k: str) -> str: db = SessionLocal() - row = db.get(Config, key) + row = db.get(Config, k) db.close() return row.value if row else "" -# Read core settings from DB DISCORD_TOKEN = get_cfg("discord_token") GUILD_ID = int(get_cfg("guild_id") or 0) LNBITS_URL = get_cfg("lnbits_url") LNBITS_API_KEY = get_cfg("lnbits_api_key") -for name, val in ( - ("discord_token", DISCORD_TOKEN), - ("guild_id", GUILD_ID), - ("lnbits_url", LNBITS_URL), - ("lnbits_api_key", LNBITS_API_KEY), -): +for name, val in (("discord_token", DISCORD_TOKEN), ("guild_id", GUILD_ID), + ("lnbits_url", LNBITS_URL), ("lnbits_api_key", LNBITS_API_KEY)): if not val: - logger.critical(f"Config '{name}' is empty in DB. Fill it via the UI first.") + logger.critical(f"Config '{name}' is empty in DB. Fill via UI first.") exit(1) -# Build WebSocket URL +# Build WS URL if LNBITS_URL.startswith("https://"): ws_base = LNBITS_URL.replace("https://", "wss://", 1) elif LNBITS_URL.startswith("http://"): @@ -126,121 +108,76 @@ else: exit(1) LNBITS_WS_URL = f"{ws_base}/api/v1/ws/{LNBITS_API_KEY}" -# ─── Discord Bot Setup ──────────────────────────────────────────────────────── -intents = discord.Intents.default() -intents.members = True +# ─── Discord Bot ────────────────────────────────────────────────────────── +intents = discord.Intents.default(); intents.members = True bot = commands.Bot(command_prefix="!", intents=intents) -# ─── Expiry computation ──────────────────────────────────────────────────────── def compute_expiry(paid_at: datetime, plan: Plan) -> datetime: if plan.expiry_type == "calendar_month": first = paid_at.replace(day=1, hour=0, minute=0, second=0, microsecond=0) return first + relativedelta(months=1) if plan.expiry_type == "rolling_days": return paid_at + relativedelta(days=plan.duration_days) - return paid_at # fixed_date handled via UI + return paid_at -# ─── Handle plan purchase ───────────────────────────────────────────────────── -async def handle_plan_purchase(interaction: discord.Interaction, plan: Plan): +async def handle_plan_purchase(interaction, plan: Plan): invoice_ch = bot.get_channel(int(plan.channel_id)) if not invoice_ch: - return await interaction.response.send_message( - "❌ Invoice channel not found.", ephemeral=True - ) - - invoice_data = { - "out": False, - "amount": plan.price_sats, - "memo": plan.invoice_message - } - headers = { - "X-Api-Key": LNBITS_API_KEY, - "Content-Type": "application/json" - } + return await interaction.response.send_message("❌ Invoice channel not found.", ephemeral=True) + invoice_data = {"out": False, "amount": plan.price_sats, "memo": plan.invoice_message} + headers = {"X-Api-Key": LNBITS_API_KEY, "Content-Type": "application/json"} loop = asyncio.get_running_loop() def create_invoice(): - r = requests.post( - f"{LNBITS_URL}/api/v1/payments", - json=invoice_data, headers=headers, timeout=10 - ) + r = requests.post(f"{LNBITS_URL}/api/v1/payments", json=invoice_data, headers=headers, timeout=10) r.raise_for_status() return r.json() - try: inv = await loop.run_in_executor(None, create_invoice) except Exception: logger.exception("Invoice creation failed") - return await interaction.response.send_message( - "❌ Failed to generate invoice.", ephemeral=True - ) - - pay_req = inv["bolt11"] - pay_hash = inv["payment_hash"] + return await interaction.response.send_message("❌ Could not generate invoice.", ephemeral=True) + pay_req, pay_hash = inv["bolt11"], inv["payment_hash"] db = SessionLocal() - sub = Subscription( - plan_id = plan.id, - user_id = str(interaction.user.id), - invoice_hash = pay_hash, - invoice_request = pay_req, - raw_invoice_json= inv - ) + sub = Subscription(plan_id=plan.id, user_id=str(interaction.user.id), + invoice_hash=pay_hash, invoice_request=pay_req, + raw_invoice_json=inv) db.add(sub); db.commit(); db.close() - logger.info(f"Stored pending {pay_hash} for user {interaction.user.id}") def make_qr_buf(): - buf = io.BytesIO() - qrcode.make(pay_req).save(buf, format="PNG") - buf.seek(0) - return buf + buf = io.BytesIO(); qrcode.make(pay_req).save(buf, format="PNG"); buf.seek(0); return buf + qr_buf = await loop.run_in_executor(None, make_qr_buf) + qr_file= File(qr_buf, "invoice.png") - qr_buf = await loop.run_in_executor(None, make_qr_buf) - qr_file = File(qr_buf, filename="invoice.png") - - embed = Embed( - title=f"⚡ Pay {plan.price_sats} sats for {plan.name} ⚡", - description=plan.invoice_message, - color=discord.Color.gold() - ) + embed = Embed(title=f"⚡ Pay {plan.price_sats} sats for {plan.name} ⚡", + description=plan.invoice_message, color=discord.Color.gold()) embed.add_field(name="Invoice", value=f"```{pay_req}```", inline=False) embed.set_image(url="attachment://invoice.png") embed.set_footer(text=f"Hash: {pay_hash}") await invoice_ch.send(content=interaction.user.mention, embed=embed, file=qr_file) - await interaction.response.send_message( - f"✅ Invoice posted in {invoice_ch.mention}", ephemeral=True - ) + await interaction.response.send_message(f"✅ Invoice posted in {invoice_ch.mention}", ephemeral=True) -# ─── WS listener ─────────────────────────────────────────────────────────────── async def lnbits_ws_listener(): await bot.wait_until_ready() - logger.info(f"Listening to LNbits WS at {LNBITS_WS_URL}") + logger.info(f"Listening on WS {LNBITS_WS_URL}") while True: try: async with websockets.connect(LNBITS_WS_URL) as ws: async for msg in ws: - data = json.loads(msg) - pay = data.get("payment", {}) - hsh = pay.get("checking_id") or pay.get("payment_hash") - if not (hsh and pay.get("status")=="success"): - continue + data = json.loads(msg); pay = data.get("payment", {}) + hsh = pay.get("checking_id") or pay.get("payment_hash") + if not (hsh and pay.get("status")=="success"): continue - db = SessionLocal() - sub = db.query(Subscription)\ - .filter_by(invoice_hash=hsh, status="pending")\ - .first() - if not sub: - db.close(); continue + db = SessionLocal() + sub= db.query(Subscription).filter_by(invoice_hash=hsh, status="pending").first() + if not sub: db.close(); continue - paid_at = datetime.utcnow() - expires_at = compute_expiry(paid_at, sub.plan) - - sub.status = "active" - sub.paid_at = paid_at - sub.expires_at = expires_at - sub.raw_payment_json = pay - db.commit(); db.close() + paid_at= datetime.utcnow() + expires_at= compute_expiry(paid_at, sub.plan) + sub.status=sub.paid_at=paid_at; sub.expires_at=expires_at + sub.raw_payment_json=pay; db.commit(); db.close() guild = bot.get_guild(GUILD_ID) member = guild.get_member(int(sub.user_id)) or await guild.fetch_member(int(sub.user_id)) @@ -248,21 +185,14 @@ async def lnbits_ws_listener(): chan = bot.get_channel(int(sub.plan.channel_id)) await member.add_roles(role, reason="Subscription paid") - await chan.send( - f"🎉 {member.mention} paid **{sub.plan.price_sats} sats** " - f"for **{sub.plan.name}**, expires {expires_at:%Y-%m-%d}." - ) + await chan.send(f"🎉 {member.mention} paid **{sub.plan.price_sats} sats** " + f"for **{sub.plan.name}**, expires {expires_at:%Y-%m-%d}.") except Exception: - logger.exception("WS error; reconnecting in 10s") - await asyncio.sleep(10) + logger.exception("WS error, reconnecting in 10s"); await asyncio.sleep(10) -# ─── Cleanup expired ─────────────────────────────────────────────────────────── async def cleanup_expired(): - now = datetime.utcnow() - db = SessionLocal() - rows = db.query(Subscription)\ - .filter(Subscription.status=="active", Subscription.expires_at <= now)\ - .all() + now= datetime.utcnow(); db= SessionLocal() + rows= db.query(Subscription).filter(Subscription.status=="active", Subscription.expires_at<=now).all() for sub in rows: guild = bot.get_guild(GUILD_ID) member = guild.get_member(int(sub.user_id)) @@ -270,38 +200,26 @@ async def cleanup_expired(): if member and role in member.roles: try: await member.remove_roles(role, reason="Subscription expired") - sub.status = "expired" - sub.role_removed_at = now - db.commit() - logger.info(f"Removed expired {role.name} from {member.display_name}") + sub.status="expired"; sub.role_removed_at=now; db.commit() + logger.info(f"Removed expired role from {member.display_name}") except Exception: - logger.exception("Error removing expired role") + logger.exception("Failed to remove expired role") db.close() -# ─── on_ready & scheduler ───────────────────────────────────────────────────── @bot.event async def on_ready(): - logger.info(f"Bot logged in as {bot.user} (ID {bot.user.id})") - - # dynamic command registration - db = SessionLocal() - plans = db.query(Plan).all() + logger.info(f"Logged in as {bot.user}") + db = SessionLocal(); plans = db.query(Plan).all() for plan in plans: - @bot.tree.command(name=plan.command_name, - description=f"Buy {plan.name} for {plan.price_sats} sats") - async def _cmd(interaction: discord.Interaction, plan=plan): + @bot.tree.command(name=plan.command_name, description=f"Buy {plan.name}") + async def _cmd(interaction, plan=plan): await handle_plan_purchase(interaction, plan) db.close() await bot.tree.sync() - - # start WS listener bot.loop.create_task(lnbits_ws_listener()) - - # schedule midnight‐UTC expiry check - sched = AsyncIOScheduler(timezone="UTC") + sched= AsyncIOScheduler(timezone="UTC") sched.add_job(cleanup_expired, "cron", hour=0, minute=0) sched.start() -# ─── Run ─────────────────────────────────────────────────────────────────────── -if __name__ == "__main__": +if __name__=="__main__": bot.run(DISCORD_TOKEN)