Update discord_lnbits_bot.py

This commit is contained in:
saulteafarmer 2025-05-14 15:12:23 +00:00
parent c914dc1753
commit 10e51b3655

View File

@ -7,7 +7,7 @@ import requests
import logging import logging
from datetime import datetime from datetime import datetime
from dateutil.relativedelta import relativedelta from dateutil.relativedelta import relativedelta
from dotenv import load_dotenv import qrcode
import discord import discord
from discord import File, Embed from discord import File, Embed
@ -18,6 +18,7 @@ import websockets
from sqlalchemy import ( from sqlalchemy import (
create_engine, Column, String, Integer, DateTime, Text, Enum, ForeignKey create_engine, Column, String, Integer, DateTime, Text, Enum, ForeignKey
) )
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker, relationship from sqlalchemy.orm import sessionmaker, relationship
@ -30,18 +31,16 @@ logging.basicConfig(
) )
logger = logging.getLogger("discord_lnbits_bot") logger = logging.getLogger("discord_lnbits_bot")
# ─── Load minimal .env & DB setup ───────────────────────────────────────────── # ─── DB Setup & Config Seeding ─────────────────────────────────────────────────
load_dotenv() # for DATABASE_URL only
DATABASE_URL = os.getenv("DATABASE_URL") DATABASE_URL = os.getenv("DATABASE_URL")
if not DATABASE_URL: if not DATABASE_URL:
logger.critical("Environment variable DATABASE_URL is missing.") logger.critical("Missing DATABASE_URL in .env")
exit(1) exit(1)
engine = create_engine(DATABASE_URL, future=True) engine = create_engine(DATABASE_URL, future=True)
SessionLocal = sessionmaker(bind=engine, expire_on_commit=False) SessionLocal = sessionmaker(bind=engine, expire_on_commit=False)
Base = declarative_base() Base = declarative_base()
# ─── ORM Models ───────────────────────────────────────────────────────────────
class Config(Base): class Config(Base):
__tablename__ = "config" __tablename__ = "config"
key = Column(String, primary_key=True) key = Column(String, primary_key=True)
@ -56,7 +55,10 @@ class Plan(Base):
channel_id = Column(String(64), nullable=False) channel_id = Column(String(64), nullable=False)
price_sats = Column(Integer, nullable=False) price_sats = Column(Integer, nullable=False)
invoice_message = Column(Text, nullable=False) invoice_message = Column(Text, nullable=False)
expiry_type = Column(Enum("fixed_date","rolling_days","calendar_month", name="expiry_type"), nullable=False) expiry_type = Column(
Enum("fixed_date", "rolling_days", "calendar_month", name="expiry_type"),
nullable=False
)
duration_days = Column(Integer, default=30) duration_days = Column(Integer, default=30)
subs = relationship("Subscription", back_populates="plan") subs = relationship("Subscription", back_populates="plan")
@ -67,25 +69,36 @@ class Subscription(Base):
user_id = Column(String(64), nullable=False) user_id = Column(String(64), nullable=False)
invoice_hash = Column(String(256), unique=True, nullable=False) invoice_hash = Column(String(256), unique=True, nullable=False)
invoice_request = Column(Text, nullable=False) invoice_request = Column(Text, nullable=False)
status = Column(Enum("pending","active","expired","cancelled", name="sub_status"), nullable=False, default="pending") status = Column(
Enum("pending","active","expired","cancelled", name="sub_status"),
nullable=False, default="pending"
)
created_at = Column(DateTime, default=datetime.utcnow) created_at = Column(DateTime, default=datetime.utcnow)
paid_at = Column(DateTime) paid_at = Column(DateTime)
expires_at = Column(DateTime) expires_at = Column(DateTime)
role_assigned_at = Column(DateTime) role_assigned_at = Column(DateTime)
role_removed_at = Column(DateTime) role_removed_at = Column(DateTime)
raw_invoice_json = Column(Text) raw_invoice_json = Column(JSONB)
raw_payment_json = Column(Text) raw_payment_json = Column(JSONB)
plan = relationship("Plan", back_populates="subs") plan = relationship("Plan", back_populates="subs")
# create tables & seed Config keys
Base.metadata.create_all(engine) Base.metadata.create_all(engine)
logger.info("✅ Database schema ready.") logger.info("✅ Database schema ready.")
session = SessionLocal()
for key in ("discord_token", "guild_id", "lnbits_url", "lnbits_api_key"):
if not session.get(Config, key):
session.add(Config(key=key, value=""))
session.commit()
session.close()
# ─── Config helper ──────────────────────────────────────────────────────────── # ─── Config helper ────────────────────────────────────────────────────────────
def get_cfg(key: str) -> str: def get_cfg(key: str) -> str:
session = SessionLocal() db = SessionLocal()
cfg = session.get(Config, key) row = db.get(Config, key)
session.close() db.close()
return cfg.value if cfg else "" return row.value if row else ""
# Read core settings from DB # Read core settings from DB
DISCORD_TOKEN = get_cfg("discord_token") DISCORD_TOKEN = get_cfg("discord_token")
@ -93,15 +106,14 @@ GUILD_ID = int(get_cfg("guild_id") or 0)
LNBITS_URL = get_cfg("lnbits_url") LNBITS_URL = get_cfg("lnbits_url")
LNBITS_API_KEY = get_cfg("lnbits_api_key") LNBITS_API_KEY = get_cfg("lnbits_api_key")
# Validate for name, val in (
for var_name, val in [
("discord_token", DISCORD_TOKEN), ("discord_token", DISCORD_TOKEN),
("guild_id", GUILD_ID), ("guild_id", GUILD_ID),
("lnbits_url", LNBITS_URL), ("lnbits_url", LNBITS_URL),
("lnbits_api_key", LNBITS_API_KEY), ("lnbits_api_key", LNBITS_API_KEY),
]: ):
if not val: if not val:
logger.critical(f"Config key '{var_name}' is missing or empty in DB.") logger.critical(f"Config '{name}' is empty in DB. Fill it via the UI first.")
exit(1) exit(1)
# Build WebSocket URL # Build WebSocket URL
@ -119,17 +131,16 @@ intents = discord.Intents.default()
intents.members = True intents.members = True
bot = commands.Bot(command_prefix="!", intents=intents) bot = commands.Bot(command_prefix="!", intents=intents)
# ─── Utility: expiry computation ─────────────────────────────────────────────── # ─── Expiry computation ────────────────────────────────────────────────────────
def compute_expiry(paid_at: datetime, plan: Plan) -> datetime: def compute_expiry(paid_at: datetime, plan: Plan) -> datetime:
if plan.expiry_type == "calendar_month": if plan.expiry_type == "calendar_month":
first = paid_at.replace(day=1, hour=0, minute=0, second=0, microsecond=0) first = paid_at.replace(day=1, hour=0, minute=0, second=0, microsecond=0)
return first + relativedelta(months=1) return first + relativedelta(months=1)
elif plan.expiry_type == "rolling_days": if plan.expiry_type == "rolling_days":
return paid_at + relativedelta(days=plan.duration_days) return paid_at + relativedelta(days=plan.duration_days)
else: # fixed_date return paid_at # fixed_date handled via UI
return paid_at # assume UI sets exact expires_at elsewhere
# ─── Handle a plan purchase ─────────────────────────────────────────────────── # ─── Handle plan purchase ─────────────────────────────────────────────────────
async def handle_plan_purchase(interaction: discord.Interaction, plan: Plan): async def handle_plan_purchase(interaction: discord.Interaction, plan: Plan):
invoice_ch = bot.get_channel(int(plan.channel_id)) invoice_ch = bot.get_channel(int(plan.channel_id))
if not invoice_ch: if not invoice_ch:
@ -149,47 +160,40 @@ async def handle_plan_purchase(interaction: discord.Interaction, plan: Plan):
loop = asyncio.get_running_loop() loop = asyncio.get_running_loop()
def create_invoice(): def create_invoice():
resp = requests.post( r = requests.post(
f"{LNBITS_URL}/api/v1/payments", f"{LNBITS_URL}/api/v1/payments",
json=invoice_data, json=invoice_data, headers=headers, timeout=10
headers=headers,
timeout=10
) )
resp.raise_for_status() r.raise_for_status()
return resp.json() return r.json()
try: try:
inv = await loop.run_in_executor(None, create_invoice) inv = await loop.run_in_executor(None, create_invoice)
except Exception: except Exception:
logger.exception("Invoice creation failed") logger.exception("Invoice creation failed")
return await interaction.response.send_message( return await interaction.response.send_message(
"❌ Failed to generate invoice. Try again later.", "❌ Failed to generate invoice.", ephemeral=True
ephemeral=True
) )
pay_req = inv["bolt11"] pay_req = inv["bolt11"]
pay_hash = inv["payment_hash"] pay_hash = inv["payment_hash"]
# Store pending subscription
db = SessionLocal() db = SessionLocal()
sub = Subscription( sub = Subscription(
plan_id = plan.id, plan_id = plan.id,
user_id = str(interaction.user.id), user_id = str(interaction.user.id),
invoice_hash = pay_hash, invoice_hash = pay_hash,
invoice_request = pay_req, invoice_request = pay_req,
raw_invoice_json= json.dumps(inv) raw_invoice_json= inv
) )
db.add(sub) db.add(sub); db.commit(); db.close()
db.commit() logger.info(f"Stored pending {pay_hash} for user {interaction.user.id}")
db.close()
logger.info(f"Stored pending subscription {pay_hash} for user {interaction.user.id}")
# Generate QR code
def make_qr_buf(): def make_qr_buf():
img = requests.compat.BytesIO() buf = io.BytesIO()
qrcode.make(pay_req).save(img, format="PNG") qrcode.make(pay_req).save(buf, format="PNG")
img.seek(0) buf.seek(0)
return img return buf
qr_buf = await loop.run_in_executor(None, make_qr_buf) qr_buf = await loop.run_in_executor(None, make_qr_buf)
qr_file = File(qr_buf, filename="invoice.png") qr_file = File(qr_buf, filename="invoice.png")
@ -208,7 +212,7 @@ async def handle_plan_purchase(interaction: discord.Interaction, plan: Plan):
f"✅ Invoice posted in {invoice_ch.mention}", ephemeral=True f"✅ Invoice posted in {invoice_ch.mention}", ephemeral=True
) )
# ─── WebSocket listener for payments ────────────────────────────────────────── # ─── WS listener ───────────────────────────────────────────────────────────────
async def lnbits_ws_listener(): async def lnbits_ws_listener():
await bot.wait_until_ready() await bot.wait_until_ready()
logger.info(f"Listening to LNbits WS at {LNBITS_WS_URL}") logger.info(f"Listening to LNbits WS at {LNBITS_WS_URL}")
@ -216,57 +220,50 @@ async def lnbits_ws_listener():
try: try:
async with websockets.connect(LNBITS_WS_URL) as ws: async with websockets.connect(LNBITS_WS_URL) as ws:
async for msg in ws: async for msg in ws:
try: data = json.loads(msg)
data = json.loads(msg) pay = data.get("payment", {})
pay = data.get("payment", {}) hsh = pay.get("checking_id") or pay.get("payment_hash")
hsh = pay.get("checking_id") or pay.get("payment_hash") if not (hsh and pay.get("status")=="success"):
if not (hsh and pay.get("status") == "success"): continue
continue
# Activate subscription db = SessionLocal()
db = SessionLocal() sub = db.query(Subscription)\
sub = db.query(Subscription)\ .filter_by(invoice_hash=hsh, status="pending")\
.filter_by(invoice_hash=hsh, status="pending")\ .first()
.first() if not sub:
if not sub: db.close(); continue
db.close()
continue
paid_at = datetime.utcnow() paid_at = datetime.utcnow()
expires_at = compute_expiry(paid_at, sub.plan) expires_at = compute_expiry(paid_at, sub.plan)
sub.status = "active" sub.status = "active"
sub.paid_at = paid_at sub.paid_at = paid_at
sub.expires_at = expires_at sub.expires_at = expires_at
sub.raw_payment_json = json.dumps(pay) sub.raw_payment_json = pay
db.commit() db.commit(); db.close()
db.close()
# Assign role & announce guild = bot.get_guild(GUILD_ID)
guild = bot.get_guild(GUILD_ID) member = guild.get_member(int(sub.user_id)) or await guild.fetch_member(int(sub.user_id))
member = guild.get_member(int(sub.user_id)) or await guild.fetch_member(int(sub.user_id)) role = guild.get_role(int(sub.plan.role_id))
role = guild.get_role(int(sub.plan.role_id)) chan = bot.get_channel(int(sub.plan.channel_id))
chan = bot.get_channel(int(sub.plan.channel_id))
await member.add_roles(role, reason="Subscription paid") await member.add_roles(role, reason="Subscription paid")
await chan.send( await chan.send(
f"🎉 {member.mention} paid **{sub.plan.price_sats} sats**" f"🎉 {member.mention} paid **{sub.plan.price_sats} sats** "
f" for **{sub.plan.name}**! Expires {expires_at:%Y-%m-%d}." f"for **{sub.plan.name}**, expires {expires_at:%Y-%m-%d}."
) )
except Exception:
logger.exception("Error processing WS message")
except Exception: except Exception:
logger.exception("WS connection error, retry in 10s") logger.exception("WS error; reconnecting in 10s")
await asyncio.sleep(10) await asyncio.sleep(10)
# ─── Cleanup expired subscriptions ───────────────────────────────────────────── # ─── Cleanup expired ───────────────────────────────────────────────────────────
async def cleanup_expired(): async def cleanup_expired():
now = datetime.utcnow() now = datetime.utcnow()
db = SessionLocal() db = SessionLocal()
expired = db.query(Subscription)\ rows = db.query(Subscription)\
.filter(Subscription.status=="active", Subscription.expires_at <= now)\ .filter(Subscription.status=="active", Subscription.expires_at <= now)\
.all() .all()
for sub in expired: for sub in rows:
guild = bot.get_guild(GUILD_ID) guild = bot.get_guild(GUILD_ID)
member = guild.get_member(int(sub.user_id)) member = guild.get_member(int(sub.user_id))
role = guild.get_role(int(sub.plan.role_id)) role = guild.get_role(int(sub.plan.role_id))
@ -276,42 +273,35 @@ async def cleanup_expired():
sub.status = "expired" sub.status = "expired"
sub.role_removed_at = now sub.role_removed_at = now
db.commit() db.commit()
logger.info(f"Removed {role.name} from {member.display_name}") logger.info(f"Removed expired {role.name} from {member.display_name}")
except Exception: except Exception:
logger.exception("Failed to remove expired role") logger.exception("Error removing expired role")
db.close() db.close()
# ─── on_ready: dynamic commands + start jobs ────────────────────────────────── # ─── on_ready & scheduler ─────────────────────────────────────────────────────
@bot.event @bot.event
async def on_ready(): async def on_ready():
logger.info(f"Logged in as {bot.user} (ID: {bot.user.id})") logger.info(f"Bot logged in as {bot.user} (ID {bot.user.id})")
# Register commands for each plan # dynamic command registration
db = SessionLocal() db = SessionLocal()
plans = db.query(Plan).all() plans = db.query(Plan).all()
for plan in plans: for plan in plans:
@bot.tree.command(name=plan.command_name, description=f"Buy {plan.name}") @bot.tree.command(name=plan.command_name,
description=f"Buy {plan.name} for {plan.price_sats} sats")
async def _cmd(interaction: discord.Interaction, plan=plan): async def _cmd(interaction: discord.Interaction, plan=plan):
await handle_plan_purchase(interaction, plan) await handle_plan_purchase(interaction, plan)
db.close() db.close()
await bot.tree.sync() await bot.tree.sync()
# Start WS listener # start WS listener
bot.loop.create_task(lnbits_ws_listener()) bot.loop.create_task(lnbits_ws_listener())
# Schedule cleanup daily at midnight UTC # schedule midnightUTC expiry check
scheduler = AsyncIOScheduler(timezone="UTC") sched = AsyncIOScheduler(timezone="UTC")
scheduler.add_job(cleanup_expired, "cron", hour=0, minute=0) sched.add_job(cleanup_expired, "cron", hour=0, minute=0)
scheduler.start() sched.start()
# ─── Entry point ─────────────────────────────────────────────────────────────── # ─── Run ───────────────────────────────────────────────────────────────────────
if __name__ == "__main__": if __name__ == "__main__":
# Seed Config table with required keys if missing
session = SessionLocal()
for key in ("discord_token","guild_id","lnbits_url","lnbits_api_key"):
if not session.get(Config, key):
session.add(Config(key=key, value=""))
session.commit()
session.close()
bot.run(DISCORD_TOKEN) bot.run(DISCORD_TOKEN)