Update discord_lnbits_bot.py

This commit is contained in:
saulteafarmer 2025-05-14 14:59:35 +00:00
parent 29224cf51c
commit 4edc7b673f

View File

@ -3,10 +3,10 @@ import os
import asyncio import asyncio
import json import json
import io import io
import qrcode
import requests import requests
import logging import logging
from datetime import datetime from datetime import datetime
from dateutil.relativedelta import relativedelta
from dotenv import load_dotenv from dotenv import load_dotenv
import discord import discord
@ -16,199 +16,138 @@ from discord.ext import commands
import websockets import websockets
from sqlalchemy import ( from sqlalchemy import (
create_engine, Column, String, Integer, DateTime create_engine, Column, String, Integer, DateTime, Text, Enum, ForeignKey
) )
from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker from sqlalchemy.orm import sessionmaker, relationship
# ─── Environment & Logging ──────────────────────────────────────────────────── from apscheduler.schedulers.asyncio import AsyncIOScheduler
load_dotenv() # loads .env into os.environ
# ─── Logging ──────────────────────────────────────────────────────────────────
logging.basicConfig( logging.basicConfig(
level=logging.INFO, level=logging.INFO,
format="%(asctime)s %(levelname)7s %(name)s | %(message)s" format="%(asctime)s %(levelname)s %(name)s | %(message)s"
) )
logger = logging.getLogger("discord_lnbits_bot") logger = logging.getLogger("discord_lnbits_bot")
# ─── Config from ENV ────────────────────────────────────────────────────────── # ─── Load minimal .env & DB setup ─────────────────────────────────────────────
DISCORD_TOKEN = os.getenv("DISCORD_TOKEN") load_dotenv() # for DATABASE_URL only
GUILD_ID = int(os.getenv("GUILD_ID", 0)) DATABASE_URL = os.getenv("DATABASE_URL")
ROLE_ID = int(os.getenv("ROLE_ID", 0)) if not DATABASE_URL:
LNBITS_URL = os.getenv("LNBITS_URL", "").rstrip("/") logger.critical("Environment variable DATABASE_URL is missing.")
LNBITS_API_KEY = os.getenv("LNBITS_API_KEY", "")
PRICE = int(os.getenv("PRICE", 0))
CHANNEL_ID = int(os.getenv("CHANNEL_ID", 0))
INVOICE_MESSAGE_TEMPLATE = os.getenv("INVOICE_MESSAGE", "Invoice for your purchase.")
COMMAND_NAME = os.getenv("COMMAND_NAME", "support")
DATABASE_URL = os.getenv("DATABASE_URL")
# Ensure all required env vars are present
for var in (
"DISCORD_TOKEN", "GUILD_ID", "ROLE_ID",
"LNBITS_URL", "LNBITS_API_KEY", "PRICE",
"CHANNEL_ID", "DATABASE_URL"
):
if not globals()[var]:
logger.critical(f"Environment variable {var} is missing.")
exit(1)
# Build WS URL from LNBITS_URL
if LNBITS_URL.startswith("https://"):
base_ws = LNBITS_URL.replace("https://", "wss://", 1)
elif LNBITS_URL.startswith("http://"):
base_ws = LNBITS_URL.replace("http://", "ws://", 1)
else:
logger.critical("LNBITS_URL must start with http:// or https://")
exit(1) exit(1)
LNBITS_WS_URL = f"{base_ws}/api/v1/ws/{LNBITS_API_KEY}"
# ─── Database Setup ───────────────────────────────────────────────────────────
Base = declarative_base()
class Payment(Base):
__tablename__ = "payments"
id = Column(Integer, primary_key=True, autoincrement=True)
payment_hash = Column(String(256), unique=True, nullable=False, index=True)
user_id = Column(String(64), nullable=False, index=True)
amount = Column(Integer, nullable=False)
paid_at = Column(DateTime, default=datetime.utcnow, nullable=False)
engine = create_engine(DATABASE_URL, future=True) engine = create_engine(DATABASE_URL, future=True)
SessionLocal = sessionmaker(bind=engine, expire_on_commit=False) SessionLocal = sessionmaker(bind=engine, expire_on_commit=False)
Base = declarative_base()
# ─── ORM Models ───────────────────────────────────────────────────────────────
class Config(Base):
__tablename__ = "config"
key = Column(String, primary_key=True)
value = Column(Text, nullable=False)
class Plan(Base):
__tablename__ = "plans"
id = Column(Integer, primary_key=True)
name = Column(Text, nullable=False, unique=True)
command_name = Column(Text, nullable=False, unique=True)
role_id = Column(String(64), nullable=False)
channel_id = Column(String(64), nullable=False)
price_sats = Column(Integer, nullable=False)
invoice_message = Column(Text, nullable=False)
expiry_type = Column(Enum("fixed_date","rolling_days","calendar_month", name="expiry_type"), nullable=False)
duration_days = Column(Integer, default=30)
subs = relationship("Subscription", back_populates="plan")
class Subscription(Base):
__tablename__ = "subscriptions"
id = Column(Integer, primary_key=True)
plan_id = Column(Integer, ForeignKey("plans.id"), nullable=False)
user_id = Column(String(64), nullable=False)
invoice_hash = Column(String(256), unique=True, nullable=False)
invoice_request = Column(Text, nullable=False)
status = Column(Enum("pending","active","expired","cancelled", name="sub_status"), nullable=False, default="pending")
created_at = Column(DateTime, default=datetime.utcnow)
paid_at = Column(DateTime)
expires_at = Column(DateTime)
role_assigned_at = Column(DateTime)
role_removed_at = Column(DateTime)
raw_invoice_json = Column(Text)
raw_payment_json = Column(Text)
plan = relationship("Plan", back_populates="subs")
Base.metadata.create_all(engine) Base.metadata.create_all(engine)
logger.info("✅ Database tables ensured.") logger.info("✅ Database schema ready.")
# ─── Config helper ────────────────────────────────────────────────────────────
def get_cfg(key: str) -> str:
session = SessionLocal()
cfg = session.get(Config, key)
session.close()
return cfg.value if cfg else ""
# Read core settings from DB
DISCORD_TOKEN = get_cfg("discord_token")
GUILD_ID = int(get_cfg("guild_id") or 0)
LNBITS_URL = get_cfg("lnbits_url")
LNBITS_API_KEY = get_cfg("lnbits_api_key")
# Validate
for var_name, val in [
("discord_token", DISCORD_TOKEN),
("guild_id", GUILD_ID),
("lnbits_url", LNBITS_URL),
("lnbits_api_key", LNBITS_API_KEY),
]:
if not val:
logger.critical(f"Config key '{var_name}' is missing or empty in DB.")
exit(1)
# Build WebSocket URL
if LNBITS_URL.startswith("https://"):
ws_base = LNBITS_URL.replace("https://", "wss://", 1)
elif LNBITS_URL.startswith("http://"):
ws_base = LNBITS_URL.replace("http://", "ws://", 1)
else:
logger.critical("LNBITS_URL must start with http:// or https://")
exit(1)
LNBITS_WS_URL = f"{ws_base}/api/v1/ws/{LNBITS_API_KEY}"
# ─── Discord Bot Setup ──────────────────────────────────────────────────────── # ─── Discord Bot Setup ────────────────────────────────────────────────────────
intents = discord.Intents.default() intents = discord.Intents.default()
intents.members = True intents.members = True
bot = commands.Bot(command_prefix="!", intents=intents) bot = commands.Bot(command_prefix="!", intents=intents)
pending_invoices = {} # payment_hash -> (user_id, interaction)
# ─── Role Assignment Handler ────────────────────────────────────────────────── # ─── Utility: expiry computation ───────────────────────────────────────────────
async def assign_role_after_payment(payment_hash: str, payment_obj: dict): def compute_expiry(paid_at: datetime, plan: Plan) -> datetime:
logger.debug(f"ENTER assign_role_after_payment for hash={payment_hash}") if plan.expiry_type == "calendar_month":
data = pending_invoices.pop(payment_hash, None) first = paid_at.replace(day=1, hour=0, minute=0, second=0, microsecond=0)
if data is None: return first + relativedelta(months=1)
logger.info(f"No pending invoice for hash={payment_hash}") elif plan.expiry_type == "rolling_days":
return return paid_at + relativedelta(days=plan.duration_days)
user_id, interaction = data else: # fixed_date
return paid_at # assume UI sets exact expires_at elsewhere
# Convert from millisatoshis to sats # ─── Handle a plan purchase ───────────────────────────────────────────────────
raw_msat = payment_obj.get("amount", 0) async def handle_plan_purchase(interaction: discord.Interaction, plan: Plan):
sat_amount = raw_msat // 1000 invoice_ch = bot.get_channel(int(plan.channel_id))
if not invoice_ch:
# Record payment in DB return await interaction.response.send_message(
session = SessionLocal() "❌ Invoice channel not found.", ephemeral=True
try:
payment = Payment(
payment_hash=payment_hash,
user_id=str(user_id),
amount=sat_amount,
paid_at=datetime.utcnow()
) )
session.add(payment)
session.commit()
logger.info(f"💾 Logged payment {payment_hash} for user {user_id} ({sat_amount} sats).")
except Exception:
session.rollback()
logger.exception("DB error saving payment")
finally:
session.close()
guild = bot.get_guild(GUILD_ID)
if not guild:
logger.error(f"Guild {GUILD_ID} not found")
return
member = guild.get_member(user_id) or await guild.fetch_member(user_id)
role = guild.get_role(ROLE_ID)
channel = bot.get_channel(CHANNEL_ID)
if not all([member, role, channel]):
logger.error("Member, role, or channel not found")
return
if role not in member.roles:
logger.debug(f"Adding role '{role.name}' to {member.display_name}")
try:
await asyncio.wait_for(
member.add_roles(role, reason="Paid via Lightning"),
timeout=10
)
await channel.send(
f"🎉 {member.mention} has paid **{sat_amount} sats** "
f"and received the **{role.name}** role!"
)
logger.info(f"✅ Role '{role.name}' assigned to {member.display_name}")
except Exception:
logger.exception("Error assigning role or sending confirmation")
else:
await channel.send(
f" {member.mention} paid again (hash={payment_hash}), "
f"role **{role.name}** already assigned."
)
logger.info(f"User {member.display_name} already had role; notified channel.")
# ─── LNbits WebSocket Listener ────────────────────────────────────────────────
async def lnbits_ws_listener():
await bot.wait_until_ready()
logger.info(f"👂 Connecting to LNbits WS at {LNBITS_WS_URL}")
while True:
try:
async with websockets.connect(LNBITS_WS_URL) as ws:
logger.info("✅ WS connected")
async for msg in ws:
logger.debug(f"WS ► {msg}")
try:
data = json.loads(msg)
pay = data.get("payment", {})
hsh = pay.get("checking_id") or pay.get("payment_hash")
amt = pay.get("amount", 0)
if hsh and pay.get("status") == "success" and amt > 0:
logger.info(f"💡 Payment received hash={hsh}, amt={amt}")
bot.loop.create_task(assign_role_after_payment(hsh, pay))
else:
logger.debug("Ignored WS update (not a paid invoice)")
except json.JSONDecodeError:
logger.warning("WS ► Non-JSON payload")
except Exception:
logger.exception("WS listener error; reconnecting in 10s")
await asyncio.sleep(10)
# ─── Slash Command Registration ───────────────────────────────────────────────
@bot.event
async def on_ready():
logger.info(f"Logged in as {bot.user} (ID: {bot.user.id})")
try:
synced = await bot.tree.sync()
logger.info(f"✅ Synced {len(synced)} command(s).")
except Exception:
logger.exception("Failed to sync commands")
bot.loop.create_task(lnbits_ws_listener())
@bot.tree.command(name=COMMAND_NAME, description="Pay to get your role via Lightning")
async def dynamic_command(interaction: discord.Interaction):
invoice_ch = bot.get_channel(CHANNEL_ID)
if invoice_ch is None:
await interaction.response.send_message(
"❌ Invoice channel not found. Check your config.",
ephemeral=True
)
return
loop = asyncio.get_running_loop()
invoice_data = { invoice_data = {
"out": False, "out": False,
"amount": PRICE, "amount": plan.price_sats,
"memo": f"Role for {interaction.user.display_name}" "memo": plan.invoice_message
} }
headers = { headers = {
"X-Api-Key": LNBITS_API_KEY, "X-Api-Key": LNBITS_API_KEY,
"Content-Type": "application/json" "Content-Type": "application/json"
} }
# Create invoice loop = asyncio.get_running_loop()
def create_invoice(): def create_invoice():
resp = requests.post( resp = requests.post(
f"{LNBITS_URL}/api/v1/payments", f"{LNBITS_URL}/api/v1/payments",
@ -220,50 +159,159 @@ async def dynamic_command(interaction: discord.Interaction):
return resp.json() return resp.json()
try: try:
invoice_json = await loop.run_in_executor(None, create_invoice) inv = await loop.run_in_executor(None, create_invoice)
except Exception: except Exception:
logger.exception("Error creating LNbits invoice") logger.exception("Invoice creation failed")
await interaction.response.send_message( return await interaction.response.send_message(
"❌ Failed to generate invoice. Try again later.", "❌ Failed to generate invoice. Try again later.",
ephemeral=True ephemeral=True
) )
return
payment_request = invoice_json["bolt11"] pay_req = inv["bolt11"]
payment_hash = invoice_json["payment_hash"] pay_hash = inv["payment_hash"]
pending_invoices[payment_hash] = (interaction.user.id, interaction)
logger.info(f"Stored pending invoice {payment_hash} for user {interaction.user.id}") # Store pending subscription
db = SessionLocal()
sub = Subscription(
plan_id = plan.id,
user_id = str(interaction.user.id),
invoice_hash = pay_hash,
invoice_request = pay_req,
raw_invoice_json= json.dumps(inv)
)
db.add(sub)
db.commit()
db.close()
logger.info(f"Stored pending subscription {pay_hash} for user {interaction.user.id}")
# Generate QR code # Generate QR code
def make_qr_buffer(): def make_qr_buf():
img = qrcode.make(payment_request) img = requests.compat.BytesIO()
buf = io.BytesIO() qrcode.make(pay_req).save(img, format="PNG")
img.save(buf, format="PNG") img.seek(0)
buf.seek(0) return img
return buf
qr_buffer = await loop.run_in_executor(None, make_qr_buffer) qr_buf = await loop.run_in_executor(None, make_qr_buf)
qr_file = File(qr_buffer, filename="invoice.png") qr_file = File(qr_buf, filename="invoice.png")
embed = Embed( embed = Embed(
title="⚡ Please Pay", title=f"⚡ Pay {plan.price_sats} sats for {plan.name}",
description=INVOICE_MESSAGE_TEMPLATE, description=plan.invoice_message,
color=discord.Color.gold() color=discord.Color.gold()
) )
embed.add_field(name="Invoice", value=f"```{payment_request}```", inline=False) embed.add_field(name="Invoice", value=f"```{pay_req}```", inline=False)
embed.add_field(name="Amount", value=f"{PRICE} sats", inline=True)
embed.set_image(url="attachment://invoice.png") embed.set_image(url="attachment://invoice.png")
embed.set_footer(text=f"Hash: {payment_hash}") embed.set_footer(text=f"Hash: {pay_hash}")
await invoice_ch.send( await invoice_ch.send(content=interaction.user.mention, embed=embed, file=qr_file)
content=interaction.user.mention,
embed=embed,
file=qr_file
)
await interaction.response.send_message( await interaction.response.send_message(
f"✅ Invoice posted in {invoice_ch.mention}", ephemeral=True f"✅ Invoice posted in {invoice_ch.mention}", ephemeral=True
) )
# ─── Run Bot ────────────────────────────────────────────────────────────────── # ─── WebSocket listener for payments ──────────────────────────────────────────
async def lnbits_ws_listener():
await bot.wait_until_ready()
logger.info(f"Listening to LNbits WS at {LNBITS_WS_URL}")
while True:
try:
async with websockets.connect(LNBITS_WS_URL) as ws:
async for msg in ws:
try:
data = json.loads(msg)
pay = data.get("payment", {})
hsh = pay.get("checking_id") or pay.get("payment_hash")
if not (hsh and pay.get("status") == "success"):
continue
# Activate subscription
db = SessionLocal()
sub = db.query(Subscription)\
.filter_by(invoice_hash=hsh, status="pending")\
.first()
if not sub:
db.close()
continue
paid_at = datetime.utcnow()
expires_at = compute_expiry(paid_at, sub.plan)
sub.status = "active"
sub.paid_at = paid_at
sub.expires_at = expires_at
sub.raw_payment_json = json.dumps(pay)
db.commit()
db.close()
# Assign role & announce
guild = bot.get_guild(GUILD_ID)
member = guild.get_member(int(sub.user_id)) or await guild.fetch_member(int(sub.user_id))
role = guild.get_role(int(sub.plan.role_id))
chan = bot.get_channel(int(sub.plan.channel_id))
await member.add_roles(role, reason="Subscription paid")
await chan.send(
f"🎉 {member.mention} paid **{sub.plan.price_sats} sats**"
f" for **{sub.plan.name}**! Expires {expires_at:%Y-%m-%d}."
)
except Exception:
logger.exception("Error processing WS message")
except Exception:
logger.exception("WS connection error, retry in 10s")
await asyncio.sleep(10)
# ─── Cleanup expired subscriptions ─────────────────────────────────────────────
async def cleanup_expired():
now = datetime.utcnow()
db = SessionLocal()
expired = db.query(Subscription)\
.filter(Subscription.status=="active", Subscription.expires_at <= now)\
.all()
for sub in expired:
guild = bot.get_guild(GUILD_ID)
member = guild.get_member(int(sub.user_id))
role = guild.get_role(int(sub.plan.role_id))
if member and role in member.roles:
try:
await member.remove_roles(role, reason="Subscription expired")
sub.status = "expired"
sub.role_removed_at = now
db.commit()
logger.info(f"Removed {role.name} from {member.display_name}")
except Exception:
logger.exception("Failed to remove expired role")
db.close()
# ─── on_ready: dynamic commands + start jobs ──────────────────────────────────
@bot.event
async def on_ready():
logger.info(f"Logged in as {bot.user} (ID: {bot.user.id})")
# Register commands for each plan
db = SessionLocal()
plans = db.query(Plan).all()
for plan in plans:
@bot.tree.command(name=plan.command_name, description=f"Buy {plan.name}")
async def _cmd(interaction: discord.Interaction, plan=plan):
await handle_plan_purchase(interaction, plan)
db.close()
await bot.tree.sync()
# Start WS listener
bot.loop.create_task(lnbits_ws_listener())
# Schedule cleanup daily at midnight UTC
scheduler = AsyncIOScheduler(timezone="UTC")
scheduler.add_job(cleanup_expired, "cron", hour=0, minute=0)
scheduler.start()
# ─── Entry point ───────────────────────────────────────────────────────────────
if __name__ == "__main__": if __name__ == "__main__":
# Seed Config table with required keys if missing
session = SessionLocal()
for key in ("discord_token","guild_id","lnbits_url","lnbits_api_key"):
if not session.get(Config, key):
session.add(Config(key=key, value=""))
session.commit()
session.close()
bot.run(DISCORD_TOKEN) bot.run(DISCORD_TOKEN)