#!/usr/bin/env python3
"""Discord bot that sells role subscriptions paid with LNbits Lightning invoices.

Flow implemented in this module:

1. Each row of the ``plans`` table becomes a slash command.
2. Invoking the command creates an LNbits invoice, stores a ``pending``
   subscription and posts the invoice (text + QR code) in the plan's channel.
3. A websocket listener waits for LNbits payment notifications, marks the
   matching subscription ``active`` and assigns the plan's role.
4. A daily scheduled job marks elapsed subscriptions ``expired`` and removes
   the role.

Runtime configuration (Discord token, guild id, LNbits URL and API key) is
read from the ``config`` table; the database itself comes from the
``DATABASE_URL`` environment variable.
"""
import asyncio
import io
import json
import logging
import os
import sys
from datetime import datetime, timezone

import discord
import qrcode
import requests
import websockets
from apscheduler.schedulers.asyncio import AsyncIOScheduler
from dateutil.relativedelta import relativedelta
from discord import Embed, File
from discord.ext import commands
from sqlalchemy import (
    Column,
    DateTime,
    Enum,
    ForeignKey,
    Integer,
    String,
    Text,
    create_engine,
)
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import relationship, sessionmaker

# ─── Logging ──────────────────────────────────────────────────────────────
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)s %(name)s | %(message)s",
)
logger = logging.getLogger("discord_lnbits_bot")


def _utcnow() -> datetime:
    """Return the current UTC time as a naive ``datetime``.

    The ``DateTime`` columns below are declared without ``timezone=True``, so
    timestamps are kept naive to avoid mixing naive and aware values.
    """
    return datetime.now(timezone.utc).replace(tzinfo=None)


# ─── Database & Config Seeding ────────────────────────────────────────────
DATABASE_URL = os.environ["DATABASE_URL"]
engine = create_engine(DATABASE_URL, future=True)
# expire_on_commit=False keeps loaded attributes readable after commit/close.
SessionLocal = sessionmaker(bind=engine, expire_on_commit=False)
Base = declarative_base()


class Config(Base):
    """Key/value runtime setting stored in the ``config`` table."""

    __tablename__ = "config"

    key = Column(String, primary_key=True)
    value = Column(Text, nullable=False)


class Plan(Base):
    """A purchasable subscription plan, exposed as one slash command."""

    __tablename__ = "plans"

    id = Column(Integer, primary_key=True)
    name = Column(Text, nullable=False, unique=True)
    command_name = Column(Text, nullable=False, unique=True)
    # Discord IDs are stored as strings and converted with int() at use sites.
    role_id = Column(String(64), nullable=False)
    channel_id = Column(String(64), nullable=False)
    price_sats = Column(Integer, nullable=False)
    invoice_message = Column(Text, nullable=False)
    expiry_type = Column(
        Enum("fixed_date", "rolling_days", "calendar_month", name="expiry_type"),
        nullable=False,
    )
    duration_days = Column(Integer, default=30)

    subs = relationship("Subscription", back_populates="plan")


class Subscription(Base):
    """One invoice issued to a user for a plan, plus its payment lifecycle."""

    __tablename__ = "subscriptions"

    id = Column(Integer, primary_key=True)
    plan_id = Column(Integer, ForeignKey("plans.id"), nullable=False)
    user_id = Column(String(64), nullable=False)
    invoice_hash = Column(String(256), unique=True, nullable=False)
    invoice_request = Column(Text, nullable=False)
    status = Column(
        Enum("pending", "active", "expired", "cancelled", name="sub_status"),
        nullable=False,
        default="pending",
    )
    created_at = Column(DateTime, default=_utcnow)
    paid_at = Column(DateTime)
    expires_at = Column(DateTime)
    role_assigned_at = Column(DateTime)
    role_removed_at = Column(DateTime)
    raw_invoice_json = Column(JSONB)
    raw_payment_json = Column(JSONB)

    plan = relationship("Plan", back_populates="subs")


Base.metadata.create_all(engine)
logger.info("✅ Database schema ready.")

# Seed Config keys (if missing) so they exist as empty rows to be filled in.
with SessionLocal() as session:
    for key in ("discord_token", "guild_id", "lnbits_url", "lnbits_api_key"):
        if not session.get(Config, key):
            session.add(Config(key=key, value=""))
    session.commit()


# ─── Config helper ────────────────────────────────────────────────────────
def get_cfg(k: str) -> str:
    """Return the value stored under config key *k*, or ``""`` if absent."""
    with SessionLocal() as db:
        row = db.get(Config, k)
        return row.value if row else ""


DISCORD_TOKEN = get_cfg("discord_token")
try:
    GUILD_ID = int(get_cfg("guild_id") or 0)
except ValueError:
    logger.critical("Config 'guild_id' must be a numeric Discord guild ID.")
    sys.exit(1)
# Trailing slashes are dropped so joined paths do not contain "//".
LNBITS_URL = get_cfg("lnbits_url").strip().rstrip("/")
LNBITS_API_KEY = get_cfg("lnbits_api_key")

for name, val in (
    ("discord_token", DISCORD_TOKEN),
    ("guild_id", GUILD_ID),
    ("lnbits_url", LNBITS_URL),
    ("lnbits_api_key", LNBITS_API_KEY),
):
    if not val:
        logger.critical("Config '%s' is empty in DB. Fill via UI first.", name)
        sys.exit(1)

# Build WS URL
if LNBITS_URL.startswith("https://"):
    ws_base = LNBITS_URL.replace("https://", "wss://", 1)
elif LNBITS_URL.startswith("http://"):
    ws_base = LNBITS_URL.replace("http://", "ws://", 1)
else:
    logger.critical("LNBITS_URL must start with http:// or https://")
    sys.exit(1)
LNBITS_WS_URL = f"{ws_base}/api/v1/ws/{LNBITS_API_KEY}"

# ─── Discord Bot ──────────────────────────────────────────────────────────
intents = discord.Intents.default()
intents.members = True
bot = commands.Bot(command_prefix="!", intents=intents)

# Seconds to wait before reconnecting the LNbits websocket.
_WS_RECONNECT_DELAY = 10
# Fallback for plans whose duration_days is NULL; mirrors the column default.
_DEFAULT_DURATION_DAYS = 30

# on_ready may fire more than once per process (e.g. after a reconnect);
# these keep the one-time startup work from being repeated.
_startup_done = False
_ws_task = None
_scheduler = None


def compute_expiry(paid_at: datetime, plan: Plan) -> datetime:
    """Return when a subscription paid at *paid_at* expires under *plan*.

    * ``calendar_month`` – midnight on the first day of the following month.
    * ``rolling_days``   – ``plan.duration_days`` days after payment
      (30 when the column is NULL).
    * anything else      – *paid_at* unchanged.
    """
    if plan.expiry_type == "calendar_month":
        first = paid_at.replace(day=1, hour=0, minute=0, second=0, microsecond=0)
        return first + relativedelta(months=1)
    if plan.expiry_type == "rolling_days":
        days = plan.duration_days
        if days is None:
            days = _DEFAULT_DURATION_DAYS
        return paid_at + relativedelta(days=days)
    # NOTE(review): "fixed_date" plans fall through to here and therefore
    # expire at the payment time; no fixed-date column is visible on Plan,
    # so this behaviour is left unchanged — confirm the intent.
    return paid_at


async def handle_plan_purchase(interaction, plan: Plan):
    """Create an LNbits invoice for *plan* and post it for the invoking user.

    The invoice is stored as a ``pending`` subscription and posted, with a QR
    code, in the plan's channel. Every outcome is reported to the user through
    an ephemeral message.
    """
    invoice_ch = bot.get_channel(int(plan.channel_id))
    if not invoice_ch:
        await interaction.response.send_message(
            "❌ Invoice channel not found.", ephemeral=True
        )
        return

    # Acknowledge right away: the invoice request below may block for up to
    # its 10 s timeout, longer than Discord waits for an initial response.
    await interaction.response.defer(ephemeral=True)

    invoice_data = {
        "out": False,
        "amount": plan.price_sats,
        "memo": plan.invoice_message,
    }
    headers = {"X-Api-Key": LNBITS_API_KEY, "Content-Type": "application/json"}
    loop = asyncio.get_running_loop()

    def create_invoice():
        # requests is blocking, so this runs in the default executor.
        r = requests.post(
            f"{LNBITS_URL}/api/v1/payments",
            json=invoice_data,
            headers=headers,
            timeout=10,
        )
        r.raise_for_status()
        return r.json()

    try:
        inv = await loop.run_in_executor(None, create_invoice)
        # Presumably some LNbits versions name this field "payment_request"
        # instead of "bolt11"; accept either — verify against the deployment.
        pay_req = inv.get("bolt11") or inv["payment_request"]
        pay_hash = inv["payment_hash"]
    except Exception:
        logger.exception("Invoice creation failed")
        await interaction.followup.send(
            "❌ Could not generate invoice.", ephemeral=True
        )
        return

    try:
        with SessionLocal() as db:
            db.add(
                Subscription(
                    plan_id=plan.id,
                    user_id=str(interaction.user.id),
                    invoice_hash=pay_hash,
                    invoice_request=pay_req,
                    raw_invoice_json=inv,
                )
            )
            db.commit()
    except SQLAlchemyError:
        logger.exception("Could not store subscription for invoice %s", pay_hash)
        await interaction.followup.send(
            "❌ Could not generate invoice.", ephemeral=True
        )
        return

    def make_qr_buf():
        buf = io.BytesIO()
        qrcode.make(pay_req).save(buf, format="PNG")
        buf.seek(0)
        return buf

    qr_buf = await loop.run_in_executor(None, make_qr_buf)
    qr_file = File(qr_buf, "invoice.png")

    embed = Embed(
        title=f"⚡ Pay {plan.price_sats} sats for {plan.name} ⚡",
        description=plan.invoice_message,
        color=discord.Color.gold(),
    )
    embed.add_field(name="Invoice", value=f"```{pay_req}```", inline=False)
    embed.set_image(url="attachment://invoice.png")
    embed.set_footer(text=f"Hash: {pay_hash}")

    try:
        await invoice_ch.send(
            content=interaction.user.mention, embed=embed, file=qr_file
        )
    except discord.HTTPException:
        logger.exception("Could not post invoice in channel %s", plan.channel_id)
        await interaction.followup.send(
            "❌ Could not post the invoice.", ephemeral=True
        )
        return

    await interaction.followup.send(
        f"✅ Invoice posted in {invoice_ch.mention}", ephemeral=True
    )


async def _activate_subscription(pay: dict) -> None:
    """Activate the pending subscription matching payment *pay*, if any.

    Marks the subscription ``active``, assigns the plan's role to the member
    and announces the payment in the plan's channel.
    """
    hashes = {h for h in (pay.get("checking_id"), pay.get("payment_hash")) if h}
    if not hashes or pay.get("status") != "success":
        return

    paid_at = _utcnow()
    with SessionLocal() as db:
        sub = (
            db.query(Subscription)
            .filter(
                Subscription.invoice_hash.in_(hashes),
                Subscription.status == "pending",
            )
            .first()
        )
        if not sub:
            return
        plan = sub.plan
        expires_at = compute_expiry(paid_at, plan)
        sub.status = "active"
        sub.paid_at = paid_at
        sub.expires_at = expires_at
        sub.raw_payment_json = pay
        db.commit()
        # Copy plain values out so nothing below touches a closed session.
        sub_id = sub.id
        user_id = int(sub.user_id)
        role_id = int(plan.role_id)
        channel_id = int(plan.channel_id)
        plan_name = plan.name
        price_sats = plan.price_sats

    guild = bot.get_guild(GUILD_ID)
    if guild is None:
        logger.error("Guild %s not found; cannot assign role for sub %s",
                     GUILD_ID, sub_id)
        return
    try:
        member = guild.get_member(user_id) or await guild.fetch_member(user_id)
    except discord.HTTPException:
        logger.exception("Could not fetch member %s for sub %s", user_id, sub_id)
        return
    role = guild.get_role(role_id)
    if role is None:
        logger.error("Role %s not found; cannot assign it for sub %s",
                     role_id, sub_id)
        return

    await member.add_roles(role, reason="Subscription paid")
    with SessionLocal() as db:
        paid_sub = db.get(Subscription, sub_id)
        if paid_sub is not None:
            paid_sub.role_assigned_at = _utcnow()
            db.commit()

    chan = bot.get_channel(channel_id)
    if chan is None:
        logger.warning("Channel %s not found; payment not announced", channel_id)
        return
    await chan.send(
        f"🎉 {member.mention} paid **{price_sats} sats** "
        f"for **{plan_name}**, expires {expires_at:%Y-%m-%d}."
    )


async def lnbits_ws_listener():
    """Listen for LNbits payment notifications forever, reconnecting on loss."""
    await bot.wait_until_ready()
    # The real URL ends with the API key, so it is never written to the log.
    logger.info("Listening on WS %s/api/v1/ws/<redacted>", ws_base)
    while True:
        try:
            async with websockets.connect(LNBITS_WS_URL) as ws:
                async for msg in ws:
                    # A bad message or Discord error must not drop the socket.
                    try:
                        data = json.loads(msg)
                        await _activate_subscription(data.get("payment") or {})
                    except Exception:
                        logger.exception("Failed to process LNbits message")
            logger.warning("WS closed, reconnecting in %ss", _WS_RECONNECT_DELAY)
        except Exception:
            logger.exception("WS error, reconnecting in 10s")
        # Sleep after a clean close too, so reconnects never spin.
        await asyncio.sleep(_WS_RECONNECT_DELAY)


async def cleanup_expired():
    """Mark elapsed active subscriptions ``expired`` and remove their role.

    A subscription is marked expired even when the member has left the guild
    or no longer holds the role. The role is kept when another still-valid
    active subscription of the same user grants it. If removing the role
    fails, the row stays ``active`` so the next run retries.
    """
    guild = bot.get_guild(GUILD_ID)
    if guild is None:
        logger.error("Guild %s not found; skipping expiry cleanup", GUILD_ID)
        return

    now = _utcnow()
    with SessionLocal() as db:
        rows = (
            db.query(Subscription)
            .filter(Subscription.status == "active", Subscription.expires_at <= now)
            .all()
        )
        for sub in rows:
            member = guild.get_member(int(sub.user_id))
            role = guild.get_role(int(sub.plan.role_id))
            still_covered = (
                db.query(Subscription)
                .join(Plan)
                .filter(
                    Subscription.user_id == sub.user_id,
                    Subscription.id != sub.id,
                    Subscription.status == "active",
                    Subscription.expires_at > now,
                    Plan.role_id == sub.plan.role_id,
                )
                .first()
                is not None
            )
            if (
                not still_covered
                and member is not None
                and role is not None
                and role in member.roles
            ):
                try:
                    await member.remove_roles(role, reason="Subscription expired")
                except discord.HTTPException:
                    logger.exception("Failed to remove expired role")
                    continue
                sub.role_removed_at = now
                logger.info("Removed expired role from %s", member.display_name)
            sub.status = "expired"
            db.commit()


def _make_plan_command(plan: Plan):
    """Return a slash-command callback that sells *plan*.

    app_commands treats every callback parameter after ``interaction`` as a
    slash-command option and requires a type annotation for it, so the plan
    is bound through this closure instead of a ``plan=plan`` default.
    """

    async def _cmd(interaction: discord.Interaction):
        await handle_plan_purchase(interaction, plan)

    return _cmd


@bot.event
async def on_ready():
    """Register plan commands and start background work once per process."""
    global _startup_done, _ws_task, _scheduler

    logger.info("Logged in as %s", bot.user)
    if _startup_done:
        return
    _startup_done = True

    with SessionLocal() as db:
        plans = db.query(Plan).all()

    for plan in plans:
        try:
            bot.tree.command(
                name=plan.command_name,
                # Truncated defensively; long plan names could otherwise
                # exceed Discord's description length limit.
                description=f"Buy {plan.name}"[:100],
            )(_make_plan_command(plan))
        except (discord.app_commands.AppCommandError, TypeError, ValueError):
            # One misconfigured plan must not block the others.
            logger.exception("Could not register command %r", plan.command_name)

    try:
        await bot.tree.sync()
    except discord.HTTPException:
        logger.exception("Slash command sync failed")

    # Keep a reference so the task cannot be garbage-collected mid-flight.
    _ws_task = asyncio.create_task(lnbits_ws_listener())

    _scheduler = AsyncIOScheduler(timezone="UTC")
    _scheduler.add_job(cleanup_expired, "cron", hour=0, minute=0)
    _scheduler.start()


if __name__ == "__main__":
    bot.run(DISCORD_TOKEN)