discord-lnbits-bot/discord_lnbits_bot.py

226 lines
9.9 KiB
Python

#!/usr/bin/env python3
import os
import asyncio
import json
import io
import requests
import logging
from datetime import datetime
from dateutil.relativedelta import relativedelta
import qrcode
import discord
from discord import File, Embed
from discord.ext import commands
import websockets
from sqlalchemy import (
create_engine, Column, String, Integer, DateTime, Text, Enum, ForeignKey
)
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker, relationship
from apscheduler.schedulers.asyncio import AsyncIOScheduler
# ─── Logging ──────────────────────────────────────────────────────────────
# Root logging configuration for the whole process; the bot itself logs
# through a fixed, named logger so its records are easy to filter.
_LOG_FORMAT = "%(asctime)s %(levelname)s %(name)s | %(message)s"
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
logger = logging.getLogger("discord_lnbits_bot")
# ─── Database & Config Seeding ────────────────────────────────────────────
# The connection string is the only setting read from the environment; all
# other settings live in the ``config`` table. Fail with a clear message
# (instead of a bare KeyError traceback) when it is missing or empty.
DATABASE_URL = os.environ.get("DATABASE_URL", "")
if not DATABASE_URL:
    logger.critical("Environment variable DATABASE_URL is not set.")
    raise SystemExit(1)
engine = create_engine(DATABASE_URL, future=True)
# expire_on_commit=False: ORM objects stay readable after commit/close, which
# the rest of the bot relies on when it uses rows after closing the session.
SessionLocal = sessionmaker(bind=engine, expire_on_commit=False)
Base = declarative_base()
class Config(Base):
    """A single runtime setting, addressed by ``key``.

    Startup seeds the keys it needs with an empty ``value`` and refuses to
    run while any of them is still empty.
    """

    __tablename__ = "config"

    key = Column(String(), primary_key=True)  # setting name, e.g. "discord_token"
    value = Column(Text(), nullable=False)    # raw string value; "" means not yet set
class Plan(Base):
    """A purchasable subscription tier, exposed as its own slash command.

    Paying ``price_sats`` grants the Discord role ``role_id``; invoices and
    payment announcements are posted in ``channel_id``. ``expiry_type``
    selects how the subscription's expiry date is computed.
    """

    __tablename__ = "plans"

    id = Column(Integer(), primary_key=True)
    name = Column(Text(), unique=True, nullable=False)
    command_name = Column(Text(), unique=True, nullable=False)  # slash-command name
    role_id = Column(String(64), nullable=False)     # Discord role ID, stored as text
    channel_id = Column(String(64), nullable=False)  # Discord channel ID, stored as text
    price_sats = Column(Integer(), nullable=False)
    invoice_message = Column(Text(), nullable=False)  # invoice memo and embed description
    expiry_type = Column(
        Enum("fixed_date", "rolling_days", "calendar_month", name="expiry_type"),
        nullable=False,
    )
    # Python-side default: only applied when a row is inserted through the ORM.
    duration_days = Column(Integer(), default=30)
    subs = relationship("Subscription", back_populates="plan")
class Subscription(Base):
    """One purchase of a Plan by a Discord user.

    Created as ``pending`` when an invoice is issued, keyed by the invoice's
    payment hash, and later moved to ``active`` / ``expired`` with the
    matching timestamp columns filled in.
    """

    __tablename__ = "subscriptions"

    id = Column(Integer(), primary_key=True)
    plan_id = Column(Integer(), ForeignKey("plans.id"), nullable=False)
    user_id = Column(String(64), nullable=False)  # Discord user ID as text
    # Payment hash (matched against LNbits notifications) and the BOLT11 string.
    invoice_hash = Column(String(256), nullable=False, unique=True)
    invoice_request = Column(Text(), nullable=False)
    status = Column(
        Enum("pending", "active", "expired", "cancelled", name="sub_status"),
        default="pending",
        nullable=False,
    )
    # Naive timestamps; the bot fills them from datetime.utcnow().
    created_at = Column(DateTime(), default=datetime.utcnow)
    paid_at = Column(DateTime())
    expires_at = Column(DateTime())
    role_assigned_at = Column(DateTime())
    role_removed_at = Column(DateTime())
    # LNbits JSON payloads exactly as received.
    raw_invoice_json = Column(JSONB())
    raw_payment_json = Column(JSONB())
    plan = relationship("Plan", back_populates="subs")
# Create any tables that do not exist yet (existing tables are left untouched).
Base.metadata.create_all(engine)
logger.info("✅ Database schema ready.")
# Seed Config keys (if missing) with an empty value.
with SessionLocal() as session:
    absent_keys = [
        cfg_key
        for cfg_key in ("discord_token", "guild_id", "lnbits_url", "lnbits_api_key")
        if session.get(Config, cfg_key) is None
    ]
    session.add_all([Config(key=cfg_key, value="") for cfg_key in absent_keys])
    session.commit()
# ─── Config helper ─────────────────────────────────────────────────────────
def get_cfg(k: str) -> str:
    """Return the value stored under config key *k*, or "" if the key is absent.

    Uses a short-lived session per call; the ``with`` block guarantees the
    session is closed even when the lookup raises (the previous version
    leaked it in that case).
    """
    with SessionLocal() as db:
        row = db.get(Config, k)
    return row.value if row else ""
# Runtime settings come from the ``config`` table. Surrounding whitespace is
# stripped (easy to paste in by accident), and a trailing "/" is dropped from
# the LNbits URL so f"{LNBITS_URL}/api/..." never contains "//".
DISCORD_TOKEN = get_cfg("discord_token").strip()
_guild_id_raw = get_cfg("guild_id").strip()
LNBITS_URL = get_cfg("lnbits_url").strip().rstrip("/")
LNBITS_API_KEY = get_cfg("lnbits_api_key").strip()

try:
    GUILD_ID = int(_guild_id_raw or 0)
except ValueError:
    logger.critical(f"Config 'guild_id' must be a numeric Discord guild ID, got {_guild_id_raw!r}.")
    # SystemExit rather than exit(): exit() is a site-module helper meant for
    # the interactive shell and is missing when Python runs with -S.
    raise SystemExit(1)

# Report every empty setting at once instead of stopping at the first one.
_missing_cfg = [
    name
    for name, val in (
        ("discord_token", DISCORD_TOKEN),
        ("guild_id", GUILD_ID),
        ("lnbits_url", LNBITS_URL),
        ("lnbits_api_key", LNBITS_API_KEY),
    )
    if not val
]
for name in _missing_cfg:
    logger.critical(f"Config '{name}' is empty in DB. Fill via UI first.")
if _missing_cfg:
    raise SystemExit(1)
# Build WS URL: derived from the HTTP base URL by swapping the scheme
# (https -> wss, http -> ws).
if LNBITS_URL.startswith("https://"):
    ws_base = LNBITS_URL.replace("https://", "wss://", 1)
elif LNBITS_URL.startswith("http://"):
    ws_base = LNBITS_URL.replace("http://", "ws://", 1)
else:
    logger.critical("LNBITS_URL must start with http:// or https://")
    # SystemExit rather than exit(): exit() is a site-module helper meant for
    # the interactive shell and is missing when Python runs with -S.
    raise SystemExit(1)
# Drop a trailing "/" so the endpoint path below never contains "//".
ws_base = ws_base.rstrip("/")
# The path embeds the LNbits API key: treat this URL as a secret (don't log it).
LNBITS_WS_URL = f"{ws_base}/api/v1/ws/{LNBITS_API_KEY}"
# ─── Discord Bot ──────────────────────────────────────────────────────────
# Default intents plus the privileged "members" intent, so guild members can
# be resolved from the cache when roles are granted or removed.
intents = discord.Intents.default()
intents.members = True
bot = commands.Bot(intents=intents, command_prefix="!")
def compute_expiry(paid_at: datetime, plan: Plan) -> datetime:
    """Return the moment a subscription to *plan*, paid at *paid_at*, expires.

    * ``calendar_month``: midnight on the 1st of the month after *paid_at*.
      A payment late in the month therefore buys only the remaining days.
    * ``rolling_days``: *paid_at* plus ``plan.duration_days`` days. A NULL
      ``duration_days`` falls back to 30, the column's ORM default.
    * anything else (``fixed_date``): *paid_at* itself.
    """
    if plan.expiry_type == "calendar_month":
        first = paid_at.replace(day=1, hour=0, minute=0, second=0, microsecond=0)
        return first + relativedelta(months=1)
    if plan.expiry_type == "rolling_days":
        # default=30 on the column only applies to rows inserted through the
        # ORM; rows written by other tools can hold NULL, and
        # relativedelta(days=None) raises TypeError.
        days = plan.duration_days if plan.duration_days is not None else 30
        return paid_at + relativedelta(days=days)
    # NOTE(review): Plan has no column holding the fixed date, so "fixed_date"
    # subscriptions expire at the moment of payment. Presumably the next
    # cleanup run then revokes them — confirm this is intended.
    return paid_at
async def handle_plan_purchase(interaction, plan: Plan):
    """Create an LNbits invoice for *plan* and post it for the invoking user.

    Records a ``pending`` Subscription keyed by the invoice's payment hash.
    Posts the BOLT11 string and a QR code in the plan's channel, mentioning
    the user, then confirms privately. Every failure is reported back to the
    user as an ephemeral message.

    Args:
        interaction: the slash-command interaction that triggered the purchase.
        plan: the Plan being bought.
    """
    invoice_ch = bot.get_channel(int(plan.channel_id))
    if not invoice_ch:
        return await interaction.response.send_message("❌ Invoice channel not found.", ephemeral=True)

    # Discord invalidates an interaction that is not answered within ~3 s, and
    # the LNbits request below may take up to its 10 s timeout. Acknowledge
    # now and answer through the follow-up webhook afterwards.
    await interaction.response.defer(ephemeral=True)

    invoice_data = {"out": False, "amount": plan.price_sats, "memo": plan.invoice_message}
    headers = {"X-Api-Key": LNBITS_API_KEY, "Content-Type": "application/json"}
    loop = asyncio.get_running_loop()

    def create_invoice():
        # Blocking HTTP call: run in the default executor to keep the loop free.
        r = requests.post(f"{LNBITS_URL}/api/v1/payments", json=invoice_data, headers=headers, timeout=10)
        r.raise_for_status()
        return r.json()

    try:
        inv = await loop.run_in_executor(None, create_invoice)
        # Accept both "bolt11" and "payment_request" as the BOLT11 field:
        # LNbits' documented create-invoice response uses "payment_request".
        pay_req = inv.get("bolt11") or inv["payment_request"]
        pay_hash = inv["payment_hash"]
    except Exception:
        logger.exception("Invoice creation failed")
        return await interaction.followup.send("❌ Could not generate invoice.", ephemeral=True)

    try:
        with SessionLocal() as db:
            db.add(Subscription(plan_id=plan.id, user_id=str(interaction.user.id),
                                invoice_hash=pay_hash, invoice_request=pay_req,
                                raw_invoice_json=inv))
            db.commit()
    except Exception:
        logger.exception(f"Failed to record subscription for invoice {pay_hash}")
        # Don't publish an invoice whose payment could never be matched.
        return await interaction.followup.send("❌ Could not generate invoice.", ephemeral=True)

    def make_qr_buf():
        buf = io.BytesIO()
        qrcode.make(pay_req).save(buf, format="PNG")
        buf.seek(0)
        return buf

    qr_buf = await loop.run_in_executor(None, make_qr_buf)
    qr_file = File(qr_buf, "invoice.png")
    embed = Embed(title=f"⚡ Pay {plan.price_sats} sats for {plan.name}",
                  description=plan.invoice_message, color=discord.Color.gold())
    embed.add_field(name="Invoice", value=f"```{pay_req}```", inline=False)
    embed.set_image(url="attachment://invoice.png")
    embed.set_footer(text=f"Hash: {pay_hash}")
    try:
        await invoice_ch.send(content=interaction.user.mention, embed=embed, file=qr_file)
    except discord.HTTPException:
        logger.exception(f"Could not post invoice in channel {plan.channel_id}")
        return await interaction.followup.send("❌ Could not post invoice.", ephemeral=True)
    await interaction.followup.send(f"✅ Invoice posted in {invoice_ch.mention}", ephemeral=True)
async def _activate_paid_subscription(pay: dict) -> None:
    """Activate the pending subscription settled by LNbits payment *pay*.

    Payments without ``status == "success"`` are ignored, as are payments
    that match no pending subscription. Otherwise the subscription is marked
    active (paid_at, expires_at, raw payment) and committed *before* Discord
    is touched, so a confirmed payment is persisted even if granting the role
    fails. The plan's role is then granted, ``role_assigned_at`` stamped, and
    the payment announced in the plan's channel.
    """
    hsh = pay.get("checking_id") or pay.get("payment_hash")
    if not (hsh and pay.get("status") == "success"):
        return

    with SessionLocal() as db:
        sub = db.query(Subscription).filter_by(invoice_hash=hsh, status="pending").first()
        if not sub:
            return
        plan = sub.plan  # load the relationship while the session is open
        paid_at = datetime.utcnow()
        expires_at = compute_expiry(paid_at, plan)
        # Assign separately: status is an enum and must hold "active",
        # not the payment timestamp.
        sub.status = "active"
        sub.paid_at = paid_at
        sub.expires_at = expires_at
        sub.raw_payment_json = pay
        db.commit()

    guild = bot.get_guild(GUILD_ID)
    if guild is None:
        logger.error(f"Guild {GUILD_ID} not available; role for subscription {sub.id} not assigned")
        return
    member = guild.get_member(int(sub.user_id)) or await guild.fetch_member(int(sub.user_id))
    role = guild.get_role(int(plan.role_id))
    if role is None:
        logger.error(f"Role {plan.role_id} of plan {plan.name} not found; subscription {sub.id} has no role")
        return
    await member.add_roles(role, reason="Subscription paid")

    with SessionLocal() as db:
        row = db.get(Subscription, sub.id)
        if row is not None:
            row.role_assigned_at = datetime.utcnow()
            db.commit()

    chan = bot.get_channel(int(plan.channel_id))
    if chan:
        await chan.send(f"🎉 {member.mention} paid **{plan.price_sats} sats** "
                        f"for **{plan.name}**, expires {expires_at:%Y-%m-%d}.")


async def lnbits_ws_listener():
    """Consume LNbits payment notifications forever and activate subscriptions.

    Reconnects 10 s after any disconnect or connection error. A failure while
    handling one message is logged without dropping the connection.
    """
    await bot.wait_until_ready()
    # The WS URL embeds the API key; never write it to the logs.
    logger.info(f"Listening on WS {LNBITS_WS_URL.replace(LNBITS_API_KEY, '***')}")
    while True:
        try:
            async with websockets.connect(LNBITS_WS_URL) as ws:
                async for msg in ws:
                    try:
                        data = json.loads(msg)
                        await _activate_paid_subscription(data.get("payment") or {})
                    except Exception:
                        logger.exception("Failed to process LNbits WS message")
            logger.warning("WS closed by server, reconnecting in 10s")
        except Exception:
            logger.exception("WS error, reconnecting in 10s")
        await asyncio.sleep(10)
async def _fetch_member_or_none(guild, user_id: int):
    """Return guild member *user_id* (cache first, then API), or None if they left.

    HTTP errors other than NotFound propagate so the caller can retry later.
    """
    member = guild.get_member(user_id)
    if member is not None:
        return member
    try:
        return await guild.fetch_member(user_id)
    except discord.NotFound:
        return None


async def cleanup_expired():
    """Expire active subscriptions whose ``expires_at`` has passed.

    For each overdue subscription, the plan's role is removed unless another
    still-valid active subscription grants the same role (e.g. an early
    renewal). The row is then marked ``expired``, even when the member left
    or no longer has the role. ``role_removed_at`` is stamped only when the
    role was actually removed. If the member lookup or the role removal
    fails, the row stays ``active`` so the next run retries it.
    """
    now = datetime.utcnow()
    guild = bot.get_guild(GUILD_ID)
    if guild is None:
        logger.error(f"Guild {GUILD_ID} not available; skipping expiry cleanup")
        return

    with SessionLocal() as db:
        rows = db.query(Subscription).filter(
            Subscription.status == "active", Subscription.expires_at <= now
        ).all()
        for sub in rows:
            role_id = sub.plan.role_id
            still_covered = db.query(Subscription).join(Subscription.plan).filter(
                Subscription.user_id == sub.user_id,
                Subscription.status == "active",
                Subscription.expires_at > now,
                Plan.role_id == role_id,
            ).first() is not None

            removed = False
            if not still_covered:
                role = guild.get_role(int(role_id))
                try:
                    member = await _fetch_member_or_none(guild, int(sub.user_id))
                    if member and role in member.roles:
                        await member.remove_roles(role, reason="Subscription expired")
                        removed = True
                        logger.info(f"Removed expired role from {member.display_name}")
                except Exception:
                    logger.exception("Failed to remove expired role")
                    continue  # leave it "active" so the next run retries

            sub.status = "expired"
            if removed:
                sub.role_removed_at = now
            db.commit()
# discord.py may dispatch on_ready again after a reconnect; this flag limits
# command registration/sync and background-job startup to one successful run.
_startup_done = False
# The event loop keeps only weak references to tasks, so hold the listener here.
_background_tasks = set()
# Handle to the daily expiry scheduler once it has been started.
_scheduler = None


def _make_plan_command(plan: Plan):
    """Build the slash-command callback that starts a purchase of *plan*.

    *plan* is captured by closure, not by a ``plan=plan`` default argument.
    discord.py treats every callback parameter after ``interaction`` as a
    slash-command option and rejects parameters without a type annotation.
    """
    async def _cmd(interaction: discord.Interaction):
        await handle_plan_purchase(interaction, plan)

    return _cmd


@bot.event
async def on_ready():
    """Register one slash command per Plan, sync them, and start background jobs.

    Setup runs once. Later on_ready events (after reconnects) only log the login.
    """
    global _startup_done, _scheduler
    logger.info(f"Logged in as {bot.user}")
    if _startup_done:
        return

    with SessionLocal() as db:
        plans = db.query(Plan).all()
    for plan in plans:
        # override=True makes re-registration harmless if an earlier attempt
        # failed before _startup_done was set (e.g. a sync error).
        bot.tree.add_command(
            discord.app_commands.Command(
                name=plan.command_name,
                description=f"Buy {plan.name}",
                callback=_make_plan_command(plan),
            ),
            override=True,
        )
    await bot.tree.sync()
    _startup_done = True

    task = asyncio.create_task(lnbits_ws_listener())
    _background_tasks.add(task)
    task.add_done_callback(_background_tasks.discard)

    _scheduler = AsyncIOScheduler(timezone="UTC")
    _scheduler.add_job(cleanup_expired, "cron", hour=0, minute=0)
    _scheduler.start()
def main() -> None:
    """Run the Discord bot; blocks until the client shuts down."""
    bot.run(DISCORD_TOKEN)


if __name__ == "__main__":
    main()