discord-lnbits-bot/discord_lnbits_bot.py

308 lines
12 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
import asyncio
import io
import json
import logging
import os
from datetime import datetime, timedelta

import discord
import qrcode
import requests
import websockets
from apscheduler.schedulers.asyncio import AsyncIOScheduler
from dateutil.relativedelta import relativedelta
from discord import Embed, File
from discord.ext import commands
from sqlalchemy import (
    Column,
    DateTime,
    Enum,
    ForeignKey,
    Integer,
    String,
    Text,
    create_engine,
)
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import relationship, sessionmaker
# ─── Logging ──────────────────────────────────────────────────────────────────
# One root configuration for the whole process; every module logger (including
# the named bot logger below) inherits this format and the INFO threshold.
_LOG_FORMAT = "%(asctime)s %(levelname)s %(name)s | %(message)s"
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
logger = logging.getLogger("discord_lnbits_bot")
# ─── DB Setup & Config Seeding ─────────────────────────────────────────────────
DATABASE_URL = os.getenv("DATABASE_URL")
if not DATABASE_URL:
    logger.critical("Missing DATABASE_URL in .env")
    # Equivalent to sys.exit(1) without relying on the interactive-only exit()
    # helper that the `site` module may or may not have installed.
    raise SystemExit(1)
# future=True opts into the SQLAlchemy 2.0-style engine API.
engine = create_engine(DATABASE_URL, future=True)
# expire_on_commit=False keeps loaded attributes readable after commit/close;
# the bot uses ORM rows (plans, subscriptions) after their session is gone.
SessionLocal = sessionmaker(bind=engine, expire_on_commit=False)
Base = declarative_base()
class Config(Base):
    """Key/value settings table; get_cfg() reads the bot's core settings from it."""
    __tablename__ = "config"
    key = Column(String, primary_key=True)  # setting name, e.g. "discord_token"
    value = Column(Text, nullable=False)  # raw text value; seeded as "" until filled in
class Plan(Base):
    """A purchasable subscription plan.

    Each row becomes one slash command (registered in on_ready). Buying it
    issues an LNbits invoice for ``price_sats``, and paying grants ``role_id``
    until the expiry computed by compute_expiry().
    """
    __tablename__ = "plans"
    id = Column(Integer, primary_key=True)
    name = Column(Text, nullable=False, unique=True)  # shown in the command description and invoice embed
    command_name = Column(Text, nullable=False, unique=True)  # slash-command name registered for this plan
    role_id = Column(String(64), nullable=False)  # Discord role ID stored as text; int()-converted at use
    channel_id = Column(String(64), nullable=False)  # channel where invoices and payment notices are posted
    price_sats = Column(Integer, nullable=False)  # "amount" sent in the LNbits invoice request
    invoice_message = Column(Text, nullable=False)  # invoice memo and embed description
    # Selects the branch taken in compute_expiry().
    expiry_type = Column(
        Enum("fixed_date", "rolling_days", "calendar_month", name="expiry_type"),
        nullable=False
    )
    # Length used by "rolling_days"; the 30 is an ORM-side insert default only,
    # so rows written outside the ORM may still hold NULL.
    duration_days = Column(Integer, default=30)
    subs = relationship("Subscription", back_populates="plan")
class Subscription(Base):
    """One invoice / subscription attempt by a Discord user for a Plan.

    Created as ``pending`` when an invoice is issued (handle_plan_purchase),
    switched to ``active`` by the LNbits websocket listener once paid, and to
    ``expired`` by cleanup_expired. No code in this file sets ``cancelled``.
    All timestamps are naive UTC (datetime.utcnow).
    """
    __tablename__ = "subscriptions"
    id = Column(Integer, primary_key=True)
    plan_id = Column(Integer, ForeignKey("plans.id"), nullable=False)
    user_id = Column(String(64), nullable=False)  # Discord user ID as text
    invoice_hash = Column(String(256), unique=True, nullable=False)  # LNbits payment_hash; matched by the WS listener
    invoice_request = Column(Text, nullable=False)  # payment request string posted to the user
    status = Column(
        Enum("pending","active","expired","cancelled", name="sub_status"),
        nullable=False, default="pending"
    )
    created_at = Column(DateTime, default=datetime.utcnow)
    paid_at = Column(DateTime)  # set when the payment event is processed
    expires_at = Column(DateTime)  # from compute_expiry(); compared to utcnow() by cleanup_expired
    role_assigned_at = Column(DateTime)  # presumably when the plan role was granted; check which paths set it
    role_removed_at = Column(DateTime)  # set when cleanup_expired strips the role
    raw_invoice_json = Column(JSONB)  # LNbits create-invoice response, stored verbatim
    raw_payment_json = Column(JSONB)  # "payment" object from the LNbits websocket message
    plan = relationship("Plan", back_populates="subs")
# create tables & seed Config keys
Base.metadata.create_all(engine)
logger.info("✅ Database schema ready.")
# Make sure every required setting has a row (empty until filled in via the
# UI); settings that already exist keep their stored value.
session = SessionLocal()
missing_keys = [
    cfg_key
    for cfg_key in ("discord_token", "guild_id", "lnbits_url", "lnbits_api_key")
    if session.get(Config, cfg_key) is None
]
session.add_all(Config(key=cfg_key, value="") for cfg_key in missing_keys)
session.commit()
session.close()
# ─── Config helper ────────────────────────────────────────────────────────────
def get_cfg(key: str) -> str:
    """Return the stored value of config *key*, or "" when the row is missing.

    The session is closed even if the lookup raises (the original leaked it).
    """
    with SessionLocal() as db:
        row = db.get(Config, key)
    # Safe after close: expire_on_commit=False and no commit, so `value` stays loaded.
    return row.value if row else ""
# Read core settings from DB.
# Values are typed into a UI, so surrounding whitespace is stripped. A trailing
# "/" on the LNbits URL is dropped so f"{LNBITS_URL}/api/..." never yields "//api".
DISCORD_TOKEN = get_cfg("discord_token").strip()
_guild_id_raw = get_cfg("guild_id").strip()
try:
    GUILD_ID = int(_guild_id_raw or 0)
except ValueError:
    # Fail with a readable message instead of a bare int() traceback.
    logger.critical(f"Config 'guild_id' must be a numeric Discord ID, got {_guild_id_raw!r}.")
    raise SystemExit(1) from None
LNBITS_URL = get_cfg("lnbits_url").strip().rstrip("/")
LNBITS_API_KEY = get_cfg("lnbits_api_key").strip()

# Refuse to start while any setting is empty (GUILD_ID == 0 counts as empty).
for name, val in (
    ("discord_token", DISCORD_TOKEN),
    ("guild_id", GUILD_ID),
    ("lnbits_url", LNBITS_URL),
    ("lnbits_api_key", LNBITS_API_KEY),
):
    if not val:
        logger.critical(f"Config '{name}' is empty in DB. Fill it via the UI first.")
        raise SystemExit(1)
# Build WebSocket URL: same host as the HTTP API, with the scheme swapped to
# its websocket counterpart (https -> wss, http -> ws).
if LNBITS_URL.startswith("https://"):
    ws_base = "wss://" + LNBITS_URL[len("https://"):]
elif LNBITS_URL.startswith("http://"):
    ws_base = "ws://" + LNBITS_URL[len("http://"):]
else:
    logger.critical("LNBITS_URL must start with http:// or https://")
    raise SystemExit(1)
# Avoid "//api/v1/ws" when the configured URL ends with a slash.
ws_base = ws_base.rstrip("/")
# NOTE: the API key is part of this URL; do not log it verbatim.
LNBITS_WS_URL = f"{ws_base}/api/v1/ws/{LNBITS_API_KEY}"
# ─── Discord Bot Setup ────────────────────────────────────────────────────────
# Default intents plus `members`: role assignment and removal look users up
# with guild.get_member(), which needs member data from the gateway.
intents = discord.Intents.default()
intents.members = True
bot = commands.Bot(intents=intents, command_prefix="!")
# ─── Expiry computation ────────────────────────────────────────────────────────
# Matches the ORM-side default of Plan.duration_days.
_DEFAULT_DURATION_DAYS = 30


def compute_expiry(paid_at: datetime, plan: "Plan") -> datetime:
    """Return when a subscription paid at *paid_at* for *plan* should expire.

    * ``calendar_month``: midnight on the 1st of the month after *paid_at*
      (the payment covers the rest of the current calendar month).
    * ``rolling_days``: *paid_at* + ``plan.duration_days`` days. A NULL
      duration (possible because the column default is ORM-side only) falls
      back to 30 instead of raising TypeError.
    * anything else (``fixed_date``): *paid_at* unchanged.
      NOTE(review): stored as-is, this makes the subscription already
      expired, so the next cleanup run strips the role unless something else
      (the original comment says "handled via UI") overwrites expires_at.
    """
    if plan.expiry_type == "calendar_month":
        first = paid_at.replace(day=1, hour=0, minute=0, second=0, microsecond=0)
        # Day 1 + 32 days always lands inside the following month (months are
        # 28-31 days long); snapping to day 1 gives that month's start.
        return (first + timedelta(days=32)).replace(day=1)
    if plan.expiry_type == "rolling_days":
        days = plan.duration_days
        if days is None:
            days = _DEFAULT_DURATION_DAYS
        return paid_at + timedelta(days=days)
    return paid_at  # fixed_date handled via UI
# ─── Handle plan purchase ─────────────────────────────────────────────────────
def _request_lnbits_invoice(invoice_data: dict) -> dict:
    """POST *invoice_data* to LNbits' payments endpoint and return the JSON body.

    Blocking (uses ``requests``), so run it in an executor.
    Raises ``requests.HTTPError`` on a non-2xx response.
    """
    headers = {
        "X-Api-Key": LNBITS_API_KEY,
        "Content-Type": "application/json",
    }
    r = requests.post(
        f"{LNBITS_URL}/api/v1/payments",
        json=invoice_data, headers=headers, timeout=10,
    )
    r.raise_for_status()
    return r.json()


def _render_qr_png(payload: str) -> io.BytesIO:
    """Render *payload* as a QR-code PNG in an in-memory buffer rewound to 0."""
    buf = io.BytesIO()
    qrcode.make(payload).save(buf, format="PNG")
    buf.seek(0)
    return buf


async def handle_plan_purchase(interaction: discord.Interaction, plan: Plan):
    """Issue an LNbits invoice for *plan* and post it for the invoking user.

    Steps:
    1. Acknowledge the interaction.
    2. Create the invoice.
    3. Store a ``pending`` Subscription keyed by the payment hash.
    4. Post the payment request and QR code in the plan's channel.

    Payment confirmation happens separately in the websocket listener. Every
    failure path now answers the user instead of leaving the interaction
    unanswered.
    """
    invoice_ch = bot.get_channel(int(plan.channel_id))
    if not invoice_ch:
        return await interaction.response.send_message(
            "❌ Invoice channel not found.", ephemeral=True
        )

    # Discord invalidates interactions that are not acknowledged within 3 s,
    # and the LNbits call alone may take up to 10 s. Defer now and answer
    # through interaction.followup afterwards.
    await interaction.response.defer(ephemeral=True)

    invoice_data = {
        "out": False,
        "amount": plan.price_sats,
        "memo": plan.invoice_message,
    }
    loop = asyncio.get_running_loop()
    try:
        inv = await loop.run_in_executor(None, _request_lnbits_invoice, invoice_data)
        # NOTE(review): older LNbits releases name this field "payment_request";
        # fall back to it. Confirm against the deployed LNbits version.
        pay_req = inv.get("bolt11") or inv["payment_request"]
        pay_hash = inv["payment_hash"]
    except Exception:
        logger.exception("Invoice creation failed")
        return await interaction.followup.send(
            "❌ Failed to generate invoice.", ephemeral=True
        )

    # Without a stored row the payment could never be matched, so do not post
    # the invoice if this fails.
    try:
        with SessionLocal() as db:
            db.add(Subscription(
                plan_id=plan.id,
                user_id=str(interaction.user.id),
                invoice_hash=pay_hash,
                invoice_request=pay_req,
                raw_invoice_json=inv,
            ))
            db.commit()
    except Exception:
        logger.exception("Storing pending subscription failed")
        return await interaction.followup.send(
            "❌ Failed to generate invoice.", ephemeral=True
        )
    logger.info(f"Stored pending {pay_hash} for user {interaction.user.id}")

    try:
        qr_buf = await loop.run_in_executor(None, _render_qr_png, pay_req)
        embed = Embed(
            title=f"⚡ Pay {plan.price_sats} sats for {plan.name}",
            description=plan.invoice_message,
            color=discord.Color.gold(),
        )
        embed.add_field(name="Invoice", value=f"```{pay_req}```", inline=False)
        embed.set_image(url="attachment://invoice.png")
        embed.set_footer(text=f"Hash: {pay_hash}")
        await invoice_ch.send(
            content=interaction.user.mention,
            embed=embed,
            file=File(qr_buf, filename="invoice.png"),
        )
    except Exception:
        # e.g. missing channel permissions or an over-long embed field.
        logger.exception("Posting invoice failed")
        return await interaction.followup.send(
            "❌ Failed to post invoice.", ephemeral=True
        )

    await interaction.followup.send(
        f"✅ Invoice posted in {invoice_ch.mention}", ephemeral=True
    )
# ─── WS listener ───────────────────────────────────────────────────────────────
async def _activate_paid_subscription(raw_msg) -> None:
    """Handle one LNbits websocket message.

    Ignores messages that are not JSON objects, are not successful payments,
    or match no ``pending`` Subscription. On a match it:
    - marks the row ``active`` with ``paid_at``/``expires_at``;
    - grants the plan role and records ``role_assigned_at``;
    - posts a notice in the plan channel.

    Discord errors propagate to the caller, which logs them per message.
    """
    try:
        data = json.loads(raw_msg)
    except ValueError:
        logger.warning("Ignoring non-JSON LNbits WS message")
        return
    pay = data.get("payment") if isinstance(data, dict) else None
    if not isinstance(pay, dict) or pay.get("status") != "success":
        return
    # invoice_hash stores the payment_hash returned at invoice creation, so try
    # that first. checking_id (what the original matched on first) is kept as
    # a fallback in case it is the only identifier present.
    candidates = [h for h in (pay.get("payment_hash"), pay.get("checking_id")) if h]
    if not candidates:
        return

    with SessionLocal() as db:
        sub = (
            db.query(Subscription)
            .filter(
                Subscription.invoice_hash.in_(candidates),
                Subscription.status == "pending",
            )
            .first()
        )
        if sub is None:
            return
        plan = sub.plan
        paid_at = datetime.utcnow()
        expires_at = compute_expiry(paid_at, plan)
        sub.status = "active"
        sub.paid_at = paid_at
        sub.expires_at = expires_at
        sub.raw_payment_json = pay
        db.commit()  # persist the payment before any Discord call can fail

        guild = bot.get_guild(GUILD_ID)
        if guild is None:
            logger.error(f"Guild {GUILD_ID} unavailable; role not granted for {sub.invoice_hash}")
            return
        user_id = int(sub.user_id)
        member = guild.get_member(user_id) or await guild.fetch_member(user_id)
        role = guild.get_role(int(plan.role_id))
        if role is None:
            logger.error(f"Role {plan.role_id} for plan {plan.name!r} not found; not granted")
            return
        await member.add_roles(role, reason="Subscription paid")
        sub.role_assigned_at = datetime.utcnow()
        db.commit()

        chan = bot.get_channel(int(plan.channel_id))
        if chan is None:
            logger.warning(f"Channel {plan.channel_id} not found; payment notice skipped")
            return
        await chan.send(
            f"🎉 {member.mention} paid **{plan.price_sats} sats** "
            f"for **{plan.name}**, expires {expires_at:%Y-%m-%d}."
        )


async def lnbits_ws_listener():
    """Consume LNbits payment events forever, reconnecting after any drop.

    Each message is processed in isolation, so one bad event (member left,
    role deleted, malformed JSON) no longer tears down the connection.
    """
    await bot.wait_until_ready()
    # The last path segment of LNBITS_WS_URL is the API key; keep it out of logs.
    redacted_url = LNBITS_WS_URL.rsplit("/", 1)[0] + "/<api-key>"
    logger.info(f"Listening to LNbits WS at {redacted_url}")
    while True:
        try:
            async with websockets.connect(LNBITS_WS_URL) as ws:
                async for msg in ws:
                    try:
                        await _activate_paid_subscription(msg)
                    except Exception:
                        logger.exception("Failed to process LNbits payment message")
        except Exception:
            logger.exception("WS error; reconnecting in 10s")
        else:
            # A clean server-side close ends the loop without raising; back off
            # here too instead of reconnecting in a tight loop.
            logger.warning("LNbits WS closed; reconnecting in 10s")
        await asyncio.sleep(10)
# ─── Cleanup expired ───────────────────────────────────────────────────────────
async def _find_member(guild, user_id: int):
    """Return the guild member for *user_id*, or None if they are not in the guild.

    Checks the member cache first, then falls back to an API fetch.
    """
    member = guild.get_member(user_id)
    if member is not None:
        return member
    try:
        return await guild.fetch_member(user_id)
    except discord.NotFound:
        return None


def _role_still_granted(db, sub, now) -> bool:
    """True if *sub*'s user has another active, unexpired sub granting the same role."""
    other = (
        db.query(Subscription)
        .join(Subscription.plan)
        .filter(
            Subscription.id != sub.id,
            Subscription.user_id == sub.user_id,
            Subscription.status == "active",
            Subscription.expires_at > now,
            Plan.role_id == sub.plan.role_id,
        )
        .first()
    )
    return other is not None


async def cleanup_expired():
    """Retire ``active`` subscriptions whose ``expires_at`` has passed.

    For each such row:
    * If the user holds another active, unexpired subscription granting the
      same role (e.g. an early renewal), the role is kept.
    * Otherwise the role is removed if the member still has it.
    * The row is then marked ``expired``, including when the member has left
      or the role is already gone. Previously such rows stayed ``active``
      forever.
    * If a Discord/DB call fails, the row stays ``active`` so the next run
      retries it.
    """
    now = datetime.utcnow()
    guild = bot.get_guild(GUILD_ID)
    if guild is None:
        logger.error(f"Guild {GUILD_ID} not available; skipping expiry cleanup")
        return
    with SessionLocal() as db:
        rows = (
            db.query(Subscription)
            .filter(Subscription.status == "active", Subscription.expires_at <= now)
            .all()
        )
        for sub in rows:
            try:
                if not _role_still_granted(db, sub, now):
                    member = await _find_member(guild, int(sub.user_id))
                    role = guild.get_role(int(sub.plan.role_id))
                    if member is not None and role is not None and role in member.roles:
                        await member.remove_roles(role, reason="Subscription expired")
                        sub.role_removed_at = now
                        logger.info(f"Removed expired {role.name} from {member.display_name}")
                sub.status = "expired"
                db.commit()
            except Exception:
                db.rollback()
                logger.exception("Error removing expired role")
# ─── on_ready & scheduler ─────────────────────────────────────────────────────
# on_ready fires again whenever discord.py has to re-establish its session.
# These guard the one-time startup work and hold strong references to the
# background task and scheduler it creates.
_startup_done = False
_ws_task = None
_scheduler = None


def _make_plan_command(plan):
    """Return a slash-command callback that buys *plan*.

    discord.py turns every callback parameter after ``interaction`` into a
    slash-command option and rejects parameters without a type annotation.
    The plan must therefore be captured by this closure, not passed as a
    ``plan=plan`` default argument.
    """
    async def _cmd(interaction: discord.Interaction):
        await handle_plan_purchase(interaction, plan)
    return _cmd


@bot.event
async def on_ready():
    """Log the login; on the first call only, register plan commands and start jobs."""
    global _startup_done, _ws_task, _scheduler
    logger.info(f"Bot logged in as {bot.user} (ID {bot.user.id})")
    if _startup_done:
        return
    _startup_done = True

    # dynamic command registration: one slash command per Plan row
    with SessionLocal() as db:
        plans = db.query(Plan).all()
    for plan in plans:
        try:
            bot.tree.command(
                name=plan.command_name,
                description=f"Buy {plan.name} for {plan.price_sats} sats",
            )(_make_plan_command(plan))
        except Exception:
            # e.g. an invalid command name: skip this plan, keep the others.
            logger.exception(f"Could not register command for plan {plan.name!r}")
    try:
        await bot.tree.sync()
    except Exception:
        # Do not let a sync failure prevent payment processing from starting.
        logger.exception("Slash-command sync failed")

    # start WS listener (reference kept so the task cannot be garbage-collected)
    _ws_task = asyncio.create_task(lnbits_ws_listener())
    # schedule midnight UTC expiry check
    _scheduler = AsyncIOScheduler(timezone="UTC")
    _scheduler.add_job(cleanup_expired, "cron", hour=0, minute=0)
    _scheduler.start()
# ─── Run ───────────────────────────────────────────────────────────────────────
if __name__ == "__main__":
    # Logging is already configured via logging.basicConfig at the top of the
    # module. log_handler=None stops discord.py from adding a second handler
    # to its own logger, whose records also propagate to the root handler and
    # would otherwise appear twice.
    bot.run(DISCORD_TOKEN, log_handler=None)