diff --git a/discord_lnbits_bot.py b/discord_lnbits_bot.py index f491bcd..3f24554 100644 --- a/discord_lnbits_bot.py +++ b/discord_lnbits_bot.py @@ -3,10 +3,10 @@ import os import asyncio import json import io -import qrcode import requests import logging from datetime import datetime +from dateutil.relativedelta import relativedelta from dotenv import load_dotenv import discord @@ -16,199 +16,138 @@ from discord.ext import commands import websockets from sqlalchemy import ( - create_engine, Column, String, Integer, DateTime + create_engine, Column, String, Integer, DateTime, Text, Enum, ForeignKey ) from sqlalchemy.ext.declarative import declarative_base -from sqlalchemy.orm import sessionmaker +from sqlalchemy.orm import sessionmaker, relationship -# ─── Environment & Logging ──────────────────────────────────────────────────── -load_dotenv() # loads .env into os.environ +from apscheduler.schedulers.asyncio import AsyncIOScheduler +# ─── Logging ────────────────────────────────────────────────────────────────── logging.basicConfig( level=logging.INFO, - format="%(asctime)s %(levelname)7s %(name)s | %(message)s" + format="%(asctime)s %(levelname)s %(name)s | %(message)s" ) logger = logging.getLogger("discord_lnbits_bot") -# ─── Config from ENV ────────────────────────────────────────────────────────── -DISCORD_TOKEN = os.getenv("DISCORD_TOKEN") -GUILD_ID = int(os.getenv("GUILD_ID", 0)) -ROLE_ID = int(os.getenv("ROLE_ID", 0)) -LNBITS_URL = os.getenv("LNBITS_URL", "").rstrip("/") -LNBITS_API_KEY = os.getenv("LNBITS_API_KEY", "") -PRICE = int(os.getenv("PRICE", 0)) -CHANNEL_ID = int(os.getenv("CHANNEL_ID", 0)) -INVOICE_MESSAGE_TEMPLATE = os.getenv("INVOICE_MESSAGE", "Invoice for your purchase.") -COMMAND_NAME = os.getenv("COMMAND_NAME", "support") -DATABASE_URL = os.getenv("DATABASE_URL") - -# Ensure all required env vars are present -for var in ( - "DISCORD_TOKEN", "GUILD_ID", "ROLE_ID", - "LNBITS_URL", "LNBITS_API_KEY", "PRICE", - "CHANNEL_ID", "DATABASE_URL" -): - if not globals()[var]: - logger.critical(f"Environment variable {var} is missing.") - exit(1) - -# Build WS URL from LNBITS_URL -if LNBITS_URL.startswith("https://"): - base_ws = LNBITS_URL.replace("https://", "wss://", 1) -elif LNBITS_URL.startswith("http://"): - base_ws = LNBITS_URL.replace("http://", "ws://", 1) -else: - logger.critical("LNBITS_URL must start with http:// or https://") +# ─── Load minimal .env & DB setup ───────────────────────────────────────────── +load_dotenv() # for DATABASE_URL only +DATABASE_URL = os.getenv("DATABASE_URL") +if not DATABASE_URL: + logger.critical("Environment variable DATABASE_URL is missing.") exit(1) -LNBITS_WS_URL = f"{base_ws}/api/v1/ws/{LNBITS_API_KEY}" - -# ─── Database Setup ─────────────────────────────────────────────────────────── -Base = declarative_base() - -class Payment(Base): - __tablename__ = "payments" - id = Column(Integer, primary_key=True, autoincrement=True) - payment_hash = Column(String(256), unique=True, nullable=False, index=True) - user_id = Column(String(64), nullable=False, index=True) - amount = Column(Integer, nullable=False) - paid_at = Column(DateTime, default=datetime.utcnow, nullable=False) engine = create_engine(DATABASE_URL, future=True) SessionLocal = sessionmaker(bind=engine, expire_on_commit=False) +Base = declarative_base() + +# ─── ORM Models ─────────────────────────────────────────────────────────────── +class Config(Base): + __tablename__ = "config" + key = Column(String, primary_key=True) + value = Column(Text, nullable=False) + +class Plan(Base): + __tablename__ = "plans" + id = Column(Integer, primary_key=True) + name = Column(Text, nullable=False, unique=True) + command_name = Column(Text, nullable=False, unique=True) + role_id = Column(String(64), nullable=False) + channel_id = Column(String(64), nullable=False) + price_sats = Column(Integer, nullable=False) + invoice_message = Column(Text, nullable=False) + expiry_type = Column(Enum("fixed_date","rolling_days","calendar_month", name="expiry_type"), nullable=False) + duration_days = Column(Integer, default=30) + subs = relationship("Subscription", back_populates="plan") + +class Subscription(Base): + __tablename__ = "subscriptions" + id = Column(Integer, primary_key=True) + plan_id = Column(Integer, ForeignKey("plans.id"), nullable=False) + user_id = Column(String(64), nullable=False) + invoice_hash = Column(String(256), unique=True, nullable=False) + invoice_request = Column(Text, nullable=False) + status = Column(Enum("pending","active","expired","cancelled", name="sub_status"), nullable=False, default="pending") + created_at = Column(DateTime, default=datetime.utcnow) + paid_at = Column(DateTime) + expires_at = Column(DateTime) + role_assigned_at = Column(DateTime) + role_removed_at = Column(DateTime) + raw_invoice_json = Column(Text) + raw_payment_json = Column(Text) + plan = relationship("Plan", back_populates="subs") + Base.metadata.create_all(engine) -logger.info("✅ Database tables ensured.") +logger.info("✅ Database schema ready.") + +# ─── Config helper ──────────────────────────────────────────────────────────── +def get_cfg(key: str) -> str: + session = SessionLocal() + cfg = session.get(Config, key) + session.close() + return cfg.value if cfg else "" + +# Read core settings from DB +DISCORD_TOKEN = get_cfg("discord_token") +GUILD_ID = int(get_cfg("guild_id") or 0) +LNBITS_URL = get_cfg("lnbits_url") +LNBITS_API_KEY = get_cfg("lnbits_api_key") + +# Validate +for var_name, val in [ + ("discord_token", DISCORD_TOKEN), + ("guild_id", GUILD_ID), + ("lnbits_url", LNBITS_URL), + ("lnbits_api_key", LNBITS_API_KEY), +]: + if not val: + logger.critical(f"Config key '{var_name}' is missing or empty in DB.") + exit(1) + +# Build WebSocket URL +if LNBITS_URL.startswith("https://"): + ws_base = LNBITS_URL.replace("https://", "wss://", 1) +elif LNBITS_URL.startswith("http://"): + ws_base = LNBITS_URL.replace("http://", "ws://", 1) +else: + logger.critical("LNBITS_URL must start with http:// or https://") + exit(1) +LNBITS_WS_URL = f"{ws_base}/api/v1/ws/{LNBITS_API_KEY}" # ─── Discord Bot Setup ──────────────────────────────────────────────────────── intents = discord.Intents.default() intents.members = True - bot = commands.Bot(command_prefix="!", intents=intents) -pending_invoices = {} # payment_hash -> (user_id, interaction) -# ─── Role Assignment Handler ────────────────────────────────────────────────── -async def assign_role_after_payment(payment_hash: str, payment_obj: dict): - logger.debug(f"ENTER assign_role_after_payment for hash={payment_hash}") - data = pending_invoices.pop(payment_hash, None) - if data is None: - logger.info(f"No pending invoice for hash={payment_hash}") - return - user_id, interaction = data +# ─── Utility: expiry computation ─────────────────────────────────────────────── +def compute_expiry(paid_at: datetime, plan: Plan) -> datetime: + if plan.expiry_type == "calendar_month": + first = paid_at.replace(day=1, hour=0, minute=0, second=0, microsecond=0) + return first + relativedelta(months=1) + elif plan.expiry_type == "rolling_days": + return paid_at + relativedelta(days=plan.duration_days) + else: # fixed_date + return paid_at # assume UI sets exact expires_at elsewhere - # Convert from millisatoshis to sats - raw_msat = payment_obj.get("amount", 0) - sat_amount = raw_msat // 1000 - - # Record payment in DB - session = SessionLocal() - try: - payment = Payment( - payment_hash=payment_hash, - user_id=str(user_id), - amount=sat_amount, - paid_at=datetime.utcnow() +# ─── Handle a plan purchase ─────────────────────────────────────────────────── +async def handle_plan_purchase(interaction: discord.Interaction, plan: Plan): + invoice_ch = bot.get_channel(int(plan.channel_id)) + if not invoice_ch: + return await interaction.response.send_message( + "❌ Invoice channel not found.", ephemeral=True ) - session.add(payment) - session.commit() - logger.info(f"💾 Logged payment {payment_hash} for user {user_id} ({sat_amount} sats).") - except Exception: - session.rollback() - logger.exception("DB error saving payment") - finally: - session.close() - guild = bot.get_guild(GUILD_ID) - if not guild: - logger.error(f"Guild {GUILD_ID} not found") - return - - member = guild.get_member(user_id) or await guild.fetch_member(user_id) - role = guild.get_role(ROLE_ID) - channel = bot.get_channel(CHANNEL_ID) - - if not all([member, role, channel]): - logger.error("Member, role, or channel not found") - return - - if role not in member.roles: - logger.debug(f"Adding role '{role.name}' to {member.display_name}") - try: - await asyncio.wait_for( - member.add_roles(role, reason="Paid via Lightning"), - timeout=10 - ) - await channel.send( - f"🎉 {member.mention} has paid **{sat_amount} sats** " - f"and received the **{role.name}** role!" - ) - logger.info(f"✅ Role '{role.name}' assigned to {member.display_name}") - except Exception: - logger.exception("Error assigning role or sending confirmation") - else: - await channel.send( - f"ℹ️ {member.mention} paid again (hash={payment_hash}), " - f"role **{role.name}** already assigned." - ) - logger.info(f"User {member.display_name} already had role; notified channel.") - -# ─── LNbits WebSocket Listener ──────────────────────────────────────────────── -async def lnbits_ws_listener(): - await bot.wait_until_ready() - logger.info(f"👂 Connecting to LNbits WS at {LNBITS_WS_URL}") - while True: - try: - async with websockets.connect(LNBITS_WS_URL) as ws: - logger.info("✅ WS connected") - async for msg in ws: - logger.debug(f"WS ► {msg}") - try: - data = json.loads(msg) - pay = data.get("payment", {}) - hsh = pay.get("checking_id") or pay.get("payment_hash") - amt = pay.get("amount", 0) - if hsh and pay.get("status") == "success" and amt > 0: - logger.info(f"💡 Payment received hash={hsh}, amt={amt}") - bot.loop.create_task(assign_role_after_payment(hsh, pay)) - else: - logger.debug("Ignored WS update (not a paid invoice)") - except json.JSONDecodeError: - logger.warning("WS ► Non-JSON payload") - except Exception: - logger.exception("WS listener error; reconnecting in 10s") - await asyncio.sleep(10) - -# ─── Slash Command Registration ─────────────────────────────────────────────── -@bot.event -async def on_ready(): - logger.info(f"Logged in as {bot.user} (ID: {bot.user.id})") - try: - synced = await bot.tree.sync() - logger.info(f"✅ Synced {len(synced)} command(s).") - except Exception: - logger.exception("Failed to sync commands") - bot.loop.create_task(lnbits_ws_listener()) - -@bot.tree.command(name=COMMAND_NAME, description="Pay to get your role via Lightning") -async def dynamic_command(interaction: discord.Interaction): - invoice_ch = bot.get_channel(CHANNEL_ID) - if invoice_ch is None: - await interaction.response.send_message( - "❌ Invoice channel not found. Check your config.", - ephemeral=True - ) - return - - loop = asyncio.get_running_loop() invoice_data = { "out": False, - "amount": PRICE, - "memo": f"Role for {interaction.user.display_name}" + "amount": plan.price_sats, + "memo": plan.invoice_message } headers = { "X-Api-Key": LNBITS_API_KEY, "Content-Type": "application/json" } - # Create invoice + loop = asyncio.get_running_loop() def create_invoice(): resp = requests.post( f"{LNBITS_URL}/api/v1/payments", @@ -220,50 +159,159 @@ async def dynamic_command(interaction: discord.Interaction): return resp.json() try: - invoice_json = await loop.run_in_executor(None, create_invoice) + inv = await loop.run_in_executor(None, create_invoice) except Exception: - logger.exception("Error creating LNbits invoice") - await interaction.response.send_message( + logger.exception("Invoice creation failed") + return await interaction.response.send_message( "❌ Failed to generate invoice. Try again later.", ephemeral=True ) - return - payment_request = invoice_json["bolt11"] - payment_hash = invoice_json["payment_hash"] - pending_invoices[payment_hash] = (interaction.user.id, interaction) - logger.info(f"Stored pending invoice {payment_hash} for user {interaction.user.id}") + pay_req = inv["bolt11"] + pay_hash = inv["payment_hash"] + + # Store pending subscription + db = SessionLocal() + sub = Subscription( + plan_id = plan.id, + user_id = str(interaction.user.id), + invoice_hash = pay_hash, + invoice_request = pay_req, + raw_invoice_json= json.dumps(inv) + ) + db.add(sub) + db.commit() + db.close() + logger.info(f"Stored pending subscription {pay_hash} for user {interaction.user.id}") # Generate QR code - def make_qr_buffer(): - img = qrcode.make(payment_request) - buf = io.BytesIO() - img.save(buf, format="PNG") - buf.seek(0) - return buf + def make_qr_buf(): + img = requests.compat.BytesIO() + qrcode.make(pay_req).save(img, format="PNG") + img.seek(0) + return img - qr_buffer = await loop.run_in_executor(None, make_qr_buffer) - qr_file = File(qr_buffer, filename="invoice.png") + qr_buf = await loop.run_in_executor(None, make_qr_buf) + qr_file = File(qr_buf, filename="invoice.png") embed = Embed( - title="⚡ Please Pay ⚡", - description=INVOICE_MESSAGE_TEMPLATE, + title=f"⚡ Pay {plan.price_sats} sats for {plan.name} ⚡", + description=plan.invoice_message, color=discord.Color.gold() ) - embed.add_field(name="Invoice", value=f"```{payment_request}```", inline=False) - embed.add_field(name="Amount", value=f"{PRICE} sats", inline=True) + embed.add_field(name="Invoice", value=f"```{pay_req}```", inline=False) embed.set_image(url="attachment://invoice.png") - embed.set_footer(text=f"Hash: {payment_hash}") + embed.set_footer(text=f"Hash: {pay_hash}") - await invoice_ch.send( - content=interaction.user.mention, - embed=embed, - file=qr_file - ) + await invoice_ch.send(content=interaction.user.mention, embed=embed, file=qr_file) await interaction.response.send_message( f"✅ Invoice posted in {invoice_ch.mention}", ephemeral=True ) -# ─── Run Bot ────────────────────────────────────────────────────────────────── +# ─── WebSocket listener for payments ────────────────────────────────────────── +async def lnbits_ws_listener(): + await bot.wait_until_ready() + logger.info(f"Listening to LNbits WS at {LNBITS_WS_URL}") + while True: + try: + async with websockets.connect(LNBITS_WS_URL) as ws: + async for msg in ws: + try: + data = json.loads(msg) + pay = data.get("payment", {}) + hsh = pay.get("checking_id") or pay.get("payment_hash") + if not (hsh and pay.get("status") == "success"): + continue + + # Activate subscription + db = SessionLocal() + sub = db.query(Subscription)\ + .filter_by(invoice_hash=hsh, status="pending")\ + .first() + if not sub: + db.close() + continue + + paid_at = datetime.utcnow() + expires_at = compute_expiry(paid_at, sub.plan) + + sub.status = "active" + sub.paid_at = paid_at + sub.expires_at = expires_at + sub.raw_payment_json = json.dumps(pay) + db.commit() + db.close() + + # Assign role & announce + guild = bot.get_guild(GUILD_ID) + member = guild.get_member(int(sub.user_id)) or await guild.fetch_member(int(sub.user_id)) + role = guild.get_role(int(sub.plan.role_id)) + chan = bot.get_channel(int(sub.plan.channel_id)) + + await member.add_roles(role, reason="Subscription paid") + await chan.send( + f"🎉 {member.mention} paid **{sub.plan.price_sats} sats**" + f" for **{sub.plan.name}**! Expires {expires_at:%Y-%m-%d}." + ) + except Exception: + logger.exception("Error processing WS message") + except Exception: + logger.exception("WS connection error, retry in 10s") + await asyncio.sleep(10) + +# ─── Cleanup expired subscriptions ───────────────────────────────────────────── +async def cleanup_expired(): + now = datetime.utcnow() + db = SessionLocal() + expired = db.query(Subscription)\ + .filter(Subscription.status=="active", Subscription.expires_at <= now)\ + .all() + for sub in expired: + guild = bot.get_guild(GUILD_ID) + member = guild.get_member(int(sub.user_id)) + role = guild.get_role(int(sub.plan.role_id)) + if member and role in member.roles: + try: + await member.remove_roles(role, reason="Subscription expired") + sub.status = "expired" + sub.role_removed_at = now + db.commit() + logger.info(f"Removed {role.name} from {member.display_name}") + except Exception: + logger.exception("Failed to remove expired role") + db.close() + +# ─── on_ready: dynamic commands + start jobs ────────────────────────────────── +@bot.event +async def on_ready(): + logger.info(f"Logged in as {bot.user} (ID: {bot.user.id})") + + # Register commands for each plan + db = SessionLocal() + plans = db.query(Plan).all() + for plan in plans: + @bot.tree.command(name=plan.command_name, description=f"Buy {plan.name}") + async def _cmd(interaction: discord.Interaction, plan=plan): + await handle_plan_purchase(interaction, plan) + db.close() + await bot.tree.sync() + + # Start WS listener + bot.loop.create_task(lnbits_ws_listener()) + + # Schedule cleanup daily at midnight UTC + scheduler = AsyncIOScheduler(timezone="UTC") + scheduler.add_job(cleanup_expired, "cron", hour=0, minute=0) + scheduler.start() + +# ─── Entry point ─────────────────────────────────────────────────────────────── if __name__ == "__main__": + # Seed Config table with required keys if missing + session = SessionLocal() + for key in ("discord_token","guild_id","lnbits_url","lnbits_api_key"): + if not session.get(Config, key): + session.add(Config(key=key, value="")) + session.commit() + session.close() + bot.run(DISCORD_TOKEN)