Commit f7ac6a01 authored by Steven's avatar Steven

refactor: auth token refresh flow and simplify user hooks

parent 333c9df2
...@@ -12,6 +12,12 @@ const EXPIRES_KEY = "memos_token_expires_at"; ...@@ -12,6 +12,12 @@ const EXPIRES_KEY = "memos_token_expires_at";
// conflicting) refresh request of our own. // conflicting) refresh request of our own.
const TOKEN_CHANNEL_NAME = "memos_token_sync"; const TOKEN_CHANNEL_NAME = "memos_token_sync";
// Token refresh policy:
// - REQUEST_TOKEN_EXPIRY_BUFFER_MS: used for normal API requests.
// - FOCUS_TOKEN_EXPIRY_BUFFER_MS: used on tab visibility restore to refresh earlier.
export const REQUEST_TOKEN_EXPIRY_BUFFER_MS = 30 * 1000;
export const FOCUS_TOKEN_EXPIRY_BUFFER_MS = 2 * 60 * 1000;
interface TokenBroadcastMessage { interface TokenBroadcastMessage {
token: string; token: string;
expiresAt: string; // ISO string expiresAt: string; // ISO string
...@@ -91,11 +97,9 @@ export const setAccessToken = (token: string | null, expiresAt?: Date): void => ...@@ -91,11 +97,9 @@ export const setAccessToken = (token: string | null, expiresAt?: Date): void =>
} }
}; };
export const isTokenExpired = (bufferMs: number = 30000): boolean => { export const isTokenExpired = (bufferMs: number = REQUEST_TOKEN_EXPIRY_BUFFER_MS): boolean => {
if (!tokenExpiresAt) return true; if (!tokenExpiresAt) return true;
// Consider expired with a safety buffer before actual expiry // Consider expired with a safety buffer before actual expiry.
// Default: 30 seconds for regular requests
// Can use longer buffer (e.g., 2 minutes) for proactive refresh
return new Date() >= new Date(tokenExpiresAt.getTime() - bufferMs); return new Date() >= new Date(tokenExpiresAt.getTime() - bufferMs);
}; };
......
import { timestampDate } from "@bufbuild/protobuf/wkt"; import { timestampDate } from "@bufbuild/protobuf/wkt";
import { Code, ConnectError, createClient, type Interceptor } from "@connectrpc/connect"; import { Code, ConnectError, createClient, type Interceptor } from "@connectrpc/connect";
import { createConnectTransport } from "@connectrpc/connect-web"; import { createConnectTransport } from "@connectrpc/connect-web";
import { getAccessToken, setAccessToken } from "./auth-state"; import { getAccessToken, isTokenExpired, REQUEST_TOKEN_EXPIRY_BUFFER_MS, setAccessToken } from "./auth-state";
import { ActivityService } from "./types/proto/api/v1/activity_service_pb"; import { ActivityService } from "./types/proto/api/v1/activity_service_pb";
import { AttachmentService } from "./types/proto/api/v1/attachment_service_pb"; import { AttachmentService } from "./types/proto/api/v1/attachment_service_pb";
import { AuthService } from "./types/proto/api/v1/auth_service_pb"; import { AuthService } from "./types/proto/api/v1/auth_service_pb";
...@@ -12,6 +12,10 @@ import { ShortcutService } from "./types/proto/api/v1/shortcut_service_pb"; ...@@ -12,6 +12,10 @@ import { ShortcutService } from "./types/proto/api/v1/shortcut_service_pb";
import { UserService } from "./types/proto/api/v1/user_service_pb"; import { UserService } from "./types/proto/api/v1/user_service_pb";
import { redirectOnAuthFailure } from "./utils/auth-redirect"; import { redirectOnAuthFailure } from "./utils/auth-redirect";
interface RequestWithHeader {
header: Headers;
}
// ============================================================================ // ============================================================================
// Constants // Constants
// ============================================================================ // ============================================================================
...@@ -87,39 +91,77 @@ export async function refreshAccessToken(): Promise<void> { ...@@ -87,39 +91,77 @@ export async function refreshAccessToken(): Promise<void> {
} }
// ============================================================================ // ============================================================================
// Authentication Interceptor // Authentication Interceptor Helpers
// ============================================================================ // ============================================================================
const authInterceptor: Interceptor = (next) => async (req) => { function setAuthorizationHeader(req: RequestWithHeader, token: string | null) {
const token = getAccessToken(); if (!token) return;
if (token) {
req.header.set("Authorization", `Bearer ${token}`); req.header.set("Authorization", `Bearer ${token}`);
} }
try { function shouldHandleUnauthenticatedRetry(error: unknown, isRetryAttempt: boolean): boolean {
return await next(req);
} catch (error) {
if (!(error instanceof ConnectError)) { if (!(error instanceof ConnectError)) {
throw error; return false;
} }
if (error.code !== Code.Unauthenticated) { if (error.code !== Code.Unauthenticated) {
throw error; return false;
} }
if (isRetryAttempt) {
return false;
}
return true;
}
if (req.header.get(RETRY_HEADER) === RETRY_HEADER_VALUE) { async function refreshAndGetAccessToken(): Promise<string> {
throw error; await refreshAccessToken();
const token = getAccessToken();
if (!token) {
throw new ConnectError("Token refresh succeeded but no token available", Code.Internal);
} }
return token;
}
async function getRequestToken(): Promise<string | null> {
let token = getAccessToken();
if (!token) {
return null;
}
// Preflight refresh: avoid sending requests with expired access tokens.
// This is especially important for public endpoints (e.g. ListMemos), where
// an expired token could otherwise be treated as anonymous and return
// guest-scoped data before the reactive 401 refresh path runs.
if (isTokenExpired(REQUEST_TOKEN_EXPIRY_BUFFER_MS)) {
try { try {
await refreshAccessToken(); token = await refreshAndGetAccessToken();
} catch {
// Keep existing reactive 401 flow as fallback.
// Protected methods still trigger refresh/redirect in the catch block below.
}
}
const newToken = getAccessToken(); return token;
if (!newToken) { }
throw new ConnectError("Token refresh succeeded but no token available", Code.Internal);
// ============================================================================
// Authentication Interceptor
// ============================================================================
const authInterceptor: Interceptor = (next) => async (req) => {
const isRetryAttempt = req.header.get(RETRY_HEADER) === RETRY_HEADER_VALUE;
const token = await getRequestToken();
setAuthorizationHeader(req, token);
try {
return await next(req);
} catch (error) {
if (!shouldHandleUnauthenticatedRetry(error, isRetryAttempt)) {
throw error;
} }
req.header.set("Authorization", `Bearer ${newToken}`); try {
const newToken = await refreshAndGetAccessToken();
setAuthorizationHeader(req, newToken);
req.header.set(RETRY_HEADER, RETRY_HEADER_VALUE); req.header.set(RETRY_HEADER, RETRY_HEADER_VALUE);
return await next(req); return await next(req);
} catch (refreshError) { } catch (refreshError) {
......
import { useEffect } from "react"; import { useEffect } from "react";
import { getAccessToken, isTokenExpired } from "@/auth-state"; import { FOCUS_TOKEN_EXPIRY_BUFFER_MS, getAccessToken, isTokenExpired } from "@/auth-state";
/** /**
* Hook that proactively refreshes the access token when the tab becomes visible * Hook that proactively refreshes the access token when the tab becomes visible
...@@ -28,8 +28,7 @@ export function useTokenRefreshOnFocus(refreshFn: () => Promise<void>, enabled: ...@@ -28,8 +28,7 @@ export function useTokenRefreshOnFocus(refreshFn: () => Promise<void>, enabled:
// Check if token is expired or expiring soon (within 2 minutes) // Check if token is expired or expiring soon (within 2 minutes)
// Use a longer buffer than normal requests to be proactive // Use a longer buffer than normal requests to be proactive
const bufferMs = 2 * 60 * 1000; // 2 minutes if (isTokenExpired(FOCUS_TOKEN_EXPIRY_BUFFER_MS)) {
if (isTokenExpired(bufferMs)) {
try { try {
console.debug("[useTokenRefreshOnFocus] Token expired/expiring, refreshing before queries refetch"); console.debug("[useTokenRefreshOnFocus] Token expired/expiring, refreshing before queries refetch");
await refreshFn(); await refreshFn();
......
import { create } from "@bufbuild/protobuf"; import { create } from "@bufbuild/protobuf";
import { FieldMaskSchema } from "@bufbuild/protobuf/wkt"; import { FieldMaskSchema } from "@bufbuild/protobuf/wkt";
import { useMutation, useQuery, useQueryClient } from "@tanstack/react-query"; import { useMutation, useQuery, useQueryClient } from "@tanstack/react-query";
import { authServiceClient, shortcutServiceClient, userServiceClient } from "@/connect"; import { shortcutServiceClient, userServiceClient } from "@/connect";
import { buildUserSettingName } from "@/helpers/resource-names"; import { buildUserSettingName } from "@/helpers/resource-names";
import useCurrentUser from "@/hooks/useCurrentUser";
import { User, UserSetting, UserSetting_GeneralSetting, UserSetting_Key, UserSettingSchema } from "@/types/proto/api/v1/user_service_pb"; import { User, UserSetting, UserSetting_GeneralSetting, UserSetting_Key, UserSettingSchema } from "@/types/proto/api/v1/user_service_pb";
// Query keys factory // Query keys factory
...@@ -18,20 +19,6 @@ export const userKeys = { ...@@ -18,20 +19,6 @@ export const userKeys = {
byNames: (names: string[]) => [...userKeys.all, "byNames", ...names.sort()] as const, byNames: (names: string[]) => [...userKeys.all, "byNames", ...names.sort()] as const,
}; };
// NOTE: This hook is currently UNUSED in favor of the AuthContext-based
// useCurrentUser hook (src/hooks/useCurrentUser.ts). This is kept for potential
// future migration to React Query for auth state.
export function useCurrentUserQuery() {
return useQuery({
queryKey: userKeys.currentUser(),
queryFn: async () => {
const { user } = await authServiceClient.getCurrentUser({});
return user;
},
staleTime: 1000 * 60 * 5, // 5 minutes - auth doesn't change often
});
}
export function useUser(name: string, options?: { enabled?: boolean }) { export function useUser(name: string, options?: { enabled?: boolean }) {
return useQuery({ return useQuery({
queryKey: userKeys.detail(name), queryKey: userKeys.detail(name),
...@@ -69,7 +56,7 @@ export function useShortcuts() { ...@@ -69,7 +56,7 @@ export function useShortcuts() {
} }
export function useNotifications() { export function useNotifications() {
const { data: currentUser } = useCurrentUserQuery(); const currentUser = useCurrentUser();
return useQuery({ return useQuery({
queryKey: userKeys.notifications(), queryKey: userKeys.notifications(),
...@@ -86,7 +73,7 @@ export function useNotifications() { ...@@ -86,7 +73,7 @@ export function useNotifications() {
} }
export function useTagCounts(forCurrentUser = false) { export function useTagCounts(forCurrentUser = false) {
const { data: currentUser } = useCurrentUserQuery(); const currentUser = useCurrentUser();
return useQuery({ return useQuery({
queryKey: forCurrentUser ? [...userKeys.stats(), "tagCounts", "current"] : [...userKeys.stats(), "tagCounts", "all"], queryKey: forCurrentUser ? [...userKeys.stats(), "tagCounts", "current"] : [...userKeys.stats(), "tagCounts", "all"],
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment